optimum-rbln 0.9.3__py3-none-any.whl → 0.9.4a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +12 -4
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/controlnet.py +1 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -1
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/modeling_base.py +12 -7
- optimum/rbln/transformers/modeling_attention_utils.py +4 -4
- optimum/rbln/transformers/modeling_outputs.py +1 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +1 -1
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +4 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +201 -62
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +106 -36
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +43 -26
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +0 -1
- optimum/rbln/transformers/models/llava/modeling_llava.py +1 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -6
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +4 -4
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +2 -2
- optimum/rbln/transformers/models/swin/modeling_swin.py +3 -3
- optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +9 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -2
- optimum/rbln/utils/import_utils.py +7 -1
- optimum/rbln/utils/submodule.py +3 -1
- {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +1 -1
- {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +52 -52
- {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

```diff
@@ -88,8 +88,12 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     def setup_runtime(self):
         # Initialize resources to be used across Runtime instances (prefill and decode phases)
         page_table_manager = RBLNPageTableManager(self.rbln_config)
-
-
+        if self.rbln_config.use_position_ids:
+            dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=self.dtype)
+        else:
+            dec_attn_mask = torch.zeros(
+                self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype
+            )
 
         common_kwargs = {
             "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
@@ -97,12 +101,13 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             "dec_attn_mask": dec_attn_mask,
             "page_table_manager": page_table_manager,
             "rbln_config": self.rbln_config,
+            "config": self.config,
         }
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             phase="prefill",
             batch_size=self.rbln_config.batch_size,
-
+            logits_last_dim=self.logits_last_dim,
             **common_kwargs,
         )
         if self.can_generate():
@@ -119,12 +124,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             self.decoder = self.decoders[self.rbln_config.batch_size]
 
     @property
-    def
-        return
-            1,
-            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
-            self.config.hidden_size,
-        )
+    def logits_last_dim(self):
+        return self.config.hidden_size
 
     @classmethod
     def get_quantized_model(
@@ -340,10 +341,10 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         model_config: PretrainedConfig,
     ):
-        num_attention_heads = getattr(model_config, "n_head", None) or
+        num_attention_heads = getattr(model_config, "n_head", None) or model_config.num_attention_heads
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or
-        hidden_size = getattr(model_config, "n_embd", None) or
+        num_hidden_layers = getattr(model_config, "n_layer", None) or model_config.num_hidden_layers
+        hidden_size = getattr(model_config, "n_embd", None) or model_config.hidden_size
         head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
         is_prefill = query_length > 1
 
@@ -439,10 +440,22 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         # Returns:
         #     RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.
 
-
-
-
+        rbln_config.sliding_window = model_config.sliding_window
+        sliding_window_layers = []
+
+        for i in range(model_config.num_hidden_layers):
+            if hasattr(model_config, "layer_types"):
+                if model_config.layer_types[i] == "sliding_attention":
+                    sliding_window_layers.append(i)
+            else:
+                sliding_window_layers.append(i)
+
+        rbln_config.sliding_window_layers = sliding_window_layers
+
+        rbln_config.cache_impl = (
+            "sliding_window" if len(sliding_window_layers) == model_config.num_hidden_layers else "hybrid"
         )
+        return rbln_config
 
     @classmethod
     def _update_attention_config(
@@ -525,8 +538,13 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         if rbln_config.max_seq_len is None:
             raise ValueError("`max_seq_len` should be specified.")
 
-
-
+        layer_types = getattr(model_config, "layer_types", None)
+        all_full_attention = layer_types is not None and all(t == "full_attention" for t in layer_types)
+
+        if (
+            getattr(model_config, "sliding_window", None) is not None
+            and getattr(model_config, "use_sliding_window", True)
+            and not all_full_attention
         ):
             rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
             if rbln_config.sliding_window is not None:
```
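The two hunks above fold the per-model sliding-window overrides into the base class: sliding-window support is enabled only when the config defines `sliding_window`, does not disable `use_sliding_window`, and has at least one non-`full_attention` layer, and the layer list plus `cache_impl` are then derived from `layer_types`. A minimal standalone sketch of that derivation, using a hypothetical `DummyConfig` stand-in rather than a real `transformers.PretrainedConfig`:

```python
# Illustrative sketch only: DummyConfig and derive_sliding_window are hypothetical names
# that mirror the logic added to _update_sliding_window_config in this diff.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class DummyConfig:
    num_hidden_layers: int = 4
    sliding_window: int = 512
    layer_types: Optional[List[str]] = None  # e.g. "sliding_attention" / "full_attention"


def derive_sliding_window(model_config: DummyConfig):
    sliding_window_layers = []
    for i in range(model_config.num_hidden_layers):
        if model_config.layer_types is not None:  # the diff checks hasattr(model_config, "layer_types")
            # Only layers explicitly marked as sliding attention use the local cache.
            if model_config.layer_types[i] == "sliding_attention":
                sliding_window_layers.append(i)
        else:
            # Without layer_types, every layer is treated as a sliding-window layer.
            sliding_window_layers.append(i)

    cache_impl = "sliding_window" if len(sliding_window_layers) == model_config.num_hidden_layers else "hybrid"
    return model_config.sliding_window, sliding_window_layers, cache_impl


# A hybrid model: sliding layers interleaved with full-attention layers.
cfg = DummyConfig(layer_types=["sliding_attention", "full_attention"] * 2)
print(derive_sliding_window(cfg))  # (512, [0, 2], 'hybrid')
```

When every layer is sliding attention (or `layer_types` is absent), the same logic yields `cache_impl == "sliding_window"`, which is why the Mistral- and Gemma3-specific overrides removed later in this diff are no longer needed.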
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

```diff
@@ -602,6 +620,9 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
     ) -> BaseModelOutputWithPast:
         """
@@ -623,24 +644,50 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
                 f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
             )
 
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
         all_last_hidden_states = []
+        all_hidden_states = (
+            tuple(
+                torch.zeros(
+                    self.rbln_config.batch_size,
+                    inputs.shape[1],
+                    self.config.hidden_size,
+                    dtype=self.rbln_config.torch_dtype,
+                )
+                for _ in range(self.config.num_hidden_layers + 1)
+            )
+            if output_hidden_states
+            else None
+        )
         for b_idx in range(self.rbln_config.batch_size):
             query_length = (
                 attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
             )
             cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
-
-                inputs[b_idx : b_idx + 1],
+            outputs = self.prefill_decoder(
+                input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
+                inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                 attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                position_ids=position_ids[b_idx : b_idx + 1] if position_ids is not None else None,
                 position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
                 cache_position=cache_position,
                 batch_idx=b_idx,
-            )
-            all_last_hidden_states.append(
+            )
+            all_last_hidden_states.append(outputs.logits)
+            if self.rbln_config.output_hidden_states:
+                for l_idx in range(self.config.num_hidden_layers + 1):
+                    all_hidden_states[l_idx][b_idx].copy_(outputs.hidden_states[l_idx][0])
 
         last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
-
-        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
+        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states, hidden_states=all_hidden_states)
 
 
 class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
@@ -666,12 +713,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
     auto_model_class = AutoModelForCausalLM
 
     @property
-    def
-        return
-            1,
-            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
-            self.config.vocab_size,
-        )
+    def logits_last_dim(self):
+        return self.config.vocab_size
 
     @classmethod
     def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
@@ -736,6 +779,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
         token_type_ids: Optional[torch.Tensor] = None,
         lora_int_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
@@ -759,6 +803,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
         )
         padded_cache_lengths = torch.zeros_like(generate_idx)
 
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
         # Prefill
         if cache_position is None:
             logits = []
@@ -774,19 +827,31 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
                     f"Input's length({input_len}) exceeds compiled max_seq_len({self.rbln_config.max_seq_len})."
                 )
 
+            all_hidden_states = (
+                tuple(
+                    torch.zeros(batch_size, input_len, self.config.hidden_size, dtype=self.rbln_config.torch_dtype)
+                    for _ in range(self.config.num_hidden_layers + 1)
+                )
+                if self.rbln_config.output_hidden_states
+                else None
+            )
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-
+                outputs = self.prefill_decoder(
                     input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
                     inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                    position_ids=position_ids[b_idx : b_idx + 1] if position_ids is not None else None,
                     cache_position=cache_position,
                     batch_idx=b_idx,
                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
                     lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
                 )
-                padded_cache_lengths[b_idx] +=
-                logits.append(
+                padded_cache_lengths[b_idx] += outputs.padded_cache_lengths
+                logits.append(outputs.logits)
+                if self.rbln_config.output_hidden_states:
+                    for l_idx in range(self.config.num_hidden_layers + 1):
+                        all_hidden_states[l_idx][b_idx].copy_(outputs.hidden_states[l_idx][0])
             logits = torch.cat(logits, dim=0)
         # Decoder
         else:
@@ -807,17 +872,22 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
                     f"or `max_length` in the generation config."
                 )
 
-
+            outputs = self.decoders[batch_size](
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_ids=position_ids if self.rbln_config.use_position_ids else None,
                 lora_int_ids=lora_int_ids,
-            )
+            )
+            logits = outputs.logits
+            all_hidden_states = outputs.hidden_states
 
         if not return_dict:
-            return logits, generate_idx, padded_cache_lengths
+            return logits, generate_idx, padded_cache_lengths, all_hidden_states
         else:
             return RBLNDecoderOnlyOutput(
-                logits=logits,
+                logits=logits,
+                generate_idx=generate_idx,
+                padded_cache_lengths=padded_cache_lengths,
+                hidden_states=all_hidden_states,
             )
```
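The forward hunks above thread `output_hidden_states` through the RBLN runtimes: the flag must match the value the model was compiled with, and during prefill a tuple of zero-filled buffers (one per layer plus the embedding output) is preallocated and filled row by row with `copy_`. A toy sketch of that accumulation pattern, with made-up shapes and a fake `run_prefill` standing in for `prefill_decoder`:

```python
import torch

# Made-up sizes; in the diff these come from rbln_config and the model config.
batch_size, seq_len, hidden_size, num_layers = 2, 8, 16, 3


def run_prefill(row: torch.Tensor):
    # Stand-in for prefill_decoder(...): per-layer hidden states for a single batch row.
    return tuple(torch.randn(1, seq_len, hidden_size) for _ in range(num_layers + 1))


# One buffer per layer output (plus the embedding output), preallocated up front.
all_hidden_states = tuple(
    torch.zeros(batch_size, seq_len, hidden_size) for _ in range(num_layers + 1)
)

inputs = torch.randn(batch_size, seq_len, hidden_size)
for b_idx in range(batch_size):
    row_states = run_prefill(inputs[b_idx : b_idx + 1])
    for l_idx in range(num_layers + 1):
        # Same in-place fill the diff uses: copy this row's states into the shared buffer.
        all_hidden_states[l_idx][b_idx].copy_(row_states[l_idx][0])

print(len(all_hidden_states), all_hidden_states[0].shape)  # 4 torch.Size([2, 8, 16])
```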
optimum/rbln/transformers/models/gemma3/gemma3_architecture.py

```diff
@@ -64,6 +64,7 @@ class Gemma3TextModel(DecoderOnlyModel):
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
         lora_int_id: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -96,7 +97,10 @@ class Gemma3TextModel(DecoderOnlyModel):
 
         sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
 
+        all_hidden_states = () if output_hidden_states else None
         for layer_idx, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
             is_sliding = True if layer_idx in self.sliding_window_layers else False
             hidden_states = layer(
                 hidden_states=hidden_states,
@@ -110,7 +114,9 @@ class Gemma3TextModel(DecoderOnlyModel):
             )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
-
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+        return hidden_states, all_hidden_states
 
 
 class Gemma3DecoderLayer(DecoderOnlyLayer):
```
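The `Gemma3TextModel` hunk above collects hidden states the way `transformers` models conventionally do: the input of each decoder layer is recorded, and the final layer-normed states are appended once more, giving `num_hidden_layers + 1` entries. A toy sketch of that loop, with plain `nn.Linear` layers standing in for the RBLN Gemma3 decoder layers:

```python
import torch
import torch.nn as nn

# Toy stand-ins: plain Linear layers instead of the compiled Gemma3 decoder layers.
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(3)])
final_norm = nn.LayerNorm(16)


def run(hidden_states: torch.Tensor, output_hidden_states: bool = True):
    all_hidden_states = () if output_hidden_states else None
    for layer in layers:
        if output_hidden_states:
            # Record the input of every layer, as the Gemma3TextModel hunk does.
            all_hidden_states += (hidden_states,)
        hidden_states = layer(hidden_states)
    hidden_states = final_norm(hidden_states)
    if output_hidden_states:
        # The final, normalized states are appended once more.
        all_hidden_states += (hidden_states,)
    return hidden_states, all_hidden_states


out, states = run(torch.randn(1, 4, 16))
print(len(states))  # 4 == num_layers + 1
```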
optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py

```diff
@@ -16,7 +16,7 @@ from typing import Optional
 import rebel
 import torch
 
-from ...modeling_outputs import
+from ...modeling_outputs import RBLNGemma3ForCausalLMOutput
 from ..decoderonly.decoderonly_runtime_utils import RBLNPytorchRuntime
 from ..decoderonly.modeling_decoderonly import RBLNRuntimeModel
 
@@ -26,7 +26,6 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         super().__init__(*args, **kwargs)
         self.image_prefill = RBLNPytorchRuntime(image_prefill)  # FIXME(taehoon)
         self.prefill = RBLNPytorchRuntime(self.runtime) if self.phase == "prefill" else None  # FIXME
-        self.decode = RBLNPytorchRuntime(self.runtime) if self.phase == "decode" else None
 
     def _prepare_prefill_inputs(self, *args, **kwargs):
         (
@@ -106,6 +105,8 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         )
 
         step = 0
+        output_logits = []
+        all_hidden_states = [] if self.rbln_config.output_hidden_states else None
         while step < query_length:
             if self.rbln_config.use_image_prefill:
                 # Check if the prefill chunk is an image prefill
@@ -146,7 +147,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
 
             if is_image_prefill:
-
+                outputs = self.image_prefill(
                     input_chunk,
                     cache_pos_chunk,
                     block_tables,
@@ -157,7 +158,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                     lora_int_ids if self.rbln_config.use_lora else None,
                 )
             else:
-
+                outputs = self.prefill(
                     input_chunk,
                     cache_pos_chunk,
                     block_tables,
@@ -168,78 +169,49 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                     lora_int_ids if self.rbln_config.use_lora else None,
                 )
 
+            if self.rbln_config.output_hidden_states:
+                output_logits.append(outputs[0])
+                all_hidden_states.append(tuple(outputs[1:]))
+            else:
+                output_logits.append(outputs)
+
             padded_cache_lengths += current_padded_cache_lengths
             step += num_processed_tokens
 
-        if
-
-
-
-
-
-
-    def decode_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-        lora_int_ids: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        if self.rbln_config.use_lora and lora_int_ids is None:
-            if self.lora_int_ids is None:
-                raise ValueError(
-                    "lora_int_id is required when using LoRA. "
-                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
-                )
-
-            lora_int_ids = self.lora_int_ids
-
-        if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
-            raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
-
-        batch_size = inputs.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
+        if self.rbln_config.output_hidden_states:
+            num_hidden_layers = len(all_hidden_states[0]) - 1
+            concatenated_hidden_states = ()
+            for l_idx in range(num_hidden_layers + 1):
+                l_hidden_states = torch.cat([hidden_states[l_idx] for hidden_states in all_hidden_states], dim=1)
+                l_hidden_states = l_hidden_states[:, :query_length, :]
+                concatenated_hidden_states += (l_hidden_states,)
 
-
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+            all_hidden_states = concatenated_hidden_states
 
-        #
-
-
-
-        if local_block_tables is None:
-            raise ValueError("local_block_tables should be provided with external block tables.")
+        # Aggregate output_logits
+        output_logits = torch.concat(output_logits, dim=-2)
+        if self.rbln_config.logits_to_keep > 0:
+            output_logits = output_logits[:, -self.rbln_config.logits_to_keep :, :]
         else:
-
-
-
-
-
-
-
-
-
-
-                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                )
-            self.dec_attn_mask[b_idx, decoding_step] = 1
-
-            attention_mask = self.dec_attn_mask
-
-            if self.batch_size < block_tables.shape[0]:
-                block_tables = block_tables[: self.batch_size]
+            output_logits = output_logits[:, :query_length, :]
+            # index copy for masked output_logits
+            if attention_mask is not None:
+                new_output_logits = torch.full(
+                    (1, attention_mask.shape[-1], output_logits.shape[-1]),
+                    fill_value=1e-10,
+                    dtype=output_logits.dtype,
+                )
+                mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+                new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
 
-
-            attention_mask = attention_mask[: self.batch_size]
+                output_logits = new_output_logits
 
-
+        if not is_external_block_tables:
+            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
 
-        return
+        return RBLNGemma3ForCausalLMOutput(
+            logits=output_logits,
+            padded_cache_lengths=padded_cache_lengths,
+            attention_mask=chunked_attention_mask,
+            hidden_states=all_hidden_states,
+        )
```
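The runtime hunk above buffers the logits of every prefill chunk, concatenates them along the sequence axis, and, when an attention mask is present, scatters the valid positions back into a padded buffer with `index_copy_`. A small self-contained illustration of that aggregation, with made-up shapes:

```python
import torch

# Made-up sizes; the real chunk outputs and mask come from the compiled Gemma3 runtime.
vocab, query_length = 11, 5
chunk_logits = [torch.randn(1, 3, vocab), torch.randn(1, 3, vocab)]  # two prefill chunks

# Aggregate the chunk outputs along the sequence axis, then drop the chunk padding.
output_logits = torch.concat(chunk_logits, dim=-2)[:, :query_length, :]

# Scatter the valid positions back into a padded buffer, as the diff does for masked inputs.
attention_mask = torch.tensor([1, 0, 1, 1, 0, 1, 1, 0])  # 5 valid tokens out of 8 positions
new_output_logits = torch.full((1, attention_mask.shape[-1], vocab), 1e-10)
mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
print(new_output_logits.shape)  # torch.Size([1, 8, 11])
```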
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py

```diff
@@ -299,28 +299,60 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
         generate_idx: Optional[torch.Tensor] = None,
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
         **lm_kwargs: Dict[str, Any],
     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.rbln_config.language_model.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.language_model.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.language_model.output_hidden_states {self.rbln_config.language_model.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
         # prefill
         if cache_position is None:
             logits = []
             inputs_embeds = self._preprocess_prefill(input_ids, inputs_embeds, pixel_values)
             batch_size = inputs_embeds.shape[0]
 
+            all_hidden_states = (
+                tuple(
+                    torch.zeros(
+                        batch_size,
+                        inputs_embeds.shape[1],
+                        self.config.text_config.hidden_size,
+                        dtype=self.rbln_config.torch_dtype,
+                    )
+                    for _ in range(self.config.text_config.num_hidden_layers + 1)
+                )
+                if self.rbln_config.language_model.output_hidden_states
+                else None
+            )
+
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
                 token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
                 cache_position = self.get_padded_cache_position(cache_position, token_type_id)
 
-
+                outputs = self.language_model.prefill_decoder(
                     inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
                     attention_mask=attention_mask[b_idx],
                     cache_position=cache_position,
                     batch_idx=b_idx,
                     token_type_ids=token_type_ids[b_idx : b_idx + 1],  # do not pass token_type_id
                 )
-                padded_cache_lengths[b_idx] +=
-                logits.append(
+                padded_cache_lengths[b_idx] += outputs.padded_cache_lengths
+                logits.append(outputs.logits)
+                if self.rbln_config.language_model.output_hidden_states:
+                    for l_idx in range(self.config.text_config.num_hidden_layers + 1):
+                        mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
+                        all_hidden_states[l_idx][b_idx].index_copy_(
+                            dim=0, index=mask_indices, source=outputs.hidden_states[l_idx][0]
+                        )
 
             logits = torch.cat(logits, dim=0)
         # decoder
@@ -334,15 +366,20 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
                     f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
                 )
 
-
+            outputs = self.language_model.decoders[batch_size](
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
-            )
+            )
+            logits = outputs.logits
+            all_hidden_states = outputs.hidden_states
 
         return RBLNDecoderOnlyOutput(
-            logits=logits,
+            logits=logits,
+            generate_idx=generate_idx,
+            padded_cache_lengths=padded_cache_lengths,
+            hidden_states=all_hidden_states,
         )
 
 
@@ -403,26 +440,6 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         )
         return embed_tokens
 
-    @classmethod
-    def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
-        sliding_window = getattr(model_config, "sliding_window", None)
-        sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
-        if sliding_window_pattern is None:
-            if hasattr(model_config, "layer_types"):
-                first_full_attention_index = model_config.layer_types.index("full_attention")
-                sliding_window_pattern = first_full_attention_index + 1
-            else:
-                raise ValueError("Cannot determine sliding_window_pattern from model_config")
-
-        if sliding_window_pattern <= model_config.num_hidden_layers:
-            rbln_config.cache_impl = "hybrid"
-            rbln_config.sliding_window = sliding_window
-            rbln_config.sliding_window_layers = [
-                i for i in range(model_config.num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
-            ]
-
-        return rbln_config
-
     @classmethod
     def _update_submodule_config(
         cls,
```

optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py

```diff
@@ -150,7 +150,7 @@ class _GroundingDinoEncoder(torch.nn.Module):
         all_attn_fused_vision = () if output_attentions else None
         all_attn_enhanced_text = () if output_attentions else None
         all_attn_deformable = () if output_attentions else None
-        for
+        for _, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_vision_states += (vision_features,)
                 encoder_text_states += (text_features,)
```

optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py

```diff
@@ -304,7 +304,6 @@ class RBLNGroundingDinoForObjectDetection(RBLNModel):
         for feature_map, mask in vision_features:
             # position encoding
             position_embeddings_list.append(self.backbone_position_embedding(feature_map, mask).to(feature_map.dtype))
-            vision_features, position_embeddings_list
 
         # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
         feature_maps = []
```

optimum/rbln/transformers/models/llava/modeling_llava.py

```diff
@@ -337,7 +337,7 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
         pooler_out_size = [pixel_values.shape[0], self.config.vision_config.hidden_size]
 
         vision_out_buffer = []
-        for
+        for _ in range(self.config.vision_config.num_hidden_layers + 2):
             vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
         if pooler_out_size is not None:
             vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=torch.float32, device="cpu"))
```

optimum/rbln/transformers/models/llava_next/modeling_llava_next.py

```diff
@@ -300,7 +300,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
         ]
         pooler_out_size = [pixel_values.shape[0] * pixel_values.shape[1], self.config.vision_config.hidden_size]
         vision_out_buffer = []
-        for
+        for _ in range(self.config.vision_config.num_hidden_layers + 2):
             vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
         vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=torch.float32, device="cpu"))
 
```

optimum/rbln/transformers/models/mistral/modeling_mistral.py

```diff
@@ -12,13 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from transformers import PretrainedConfig
 
 from ....utils import logging
 from ...models.decoderonly import (
     RBLNDecoderOnlyModel,
     RBLNDecoderOnlyModelForCausalLM,
-    RBLNDecoderOnlyModelForCausalLMConfig,
 )
 from .mistral_architecture import MistralWrapper
 
@@ -85,16 +83,6 @@ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
 
     _decoder_wrapper_cls = MistralWrapper
 
-    @classmethod
-    def _update_sliding_window_config(
-        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
-    ):
-        rbln_config.cache_impl = "sliding_window"
-        rbln_config.sliding_window = model_config.sliding_window
-        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
-
-        return rbln_config
-
 
 class RBLNMistralModel(RBLNDecoderOnlyModel):
     """
@@ -103,13 +91,3 @@ class RBLNMistralModel(RBLNDecoderOnlyModel):
     """
 
     _decoder_wrapper_cls = MistralWrapper
-
-    @classmethod
-    def _update_sliding_window_config(
-        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
-    ):
-        rbln_config.cache_impl = "sliding_window"
-        rbln_config.sliding_window = model_config.sliding_window
-        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
-
-        return rbln_config
```