optimum-rbln 0.7.5a0__py3-none-any.whl → 0.7.5a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. optimum/rbln/__init__.py +20 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +9 -4
  4. optimum/rbln/modeling.py +7 -5
  5. optimum/rbln/ops/__init__.py +1 -0
  6. optimum/rbln/ops/attn.py +10 -0
  7. optimum/rbln/ops/flash_attn.py +8 -0
  8. optimum/rbln/ops/sliding_window_attn.py +111 -0
  9. optimum/rbln/transformers/__init__.py +22 -3
  10. optimum/rbln/transformers/models/__init__.py +23 -0
  11. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  12. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +93 -0
  13. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +298 -0
  14. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +12 -6
  15. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +81 -77
  16. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +160 -88
  17. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +11 -7
  18. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
  19. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  20. optimum/rbln/transformers/models/opt/configuration_opt.py +19 -0
  21. optimum/rbln/transformers/models/opt/modeling_opt.py +78 -0
  22. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  23. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +16 -10
  24. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +35 -52
  25. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -0
  26. optimum/rbln/transformers/models/siglip/__init__.py +20 -0
  27. optimum/rbln/transformers/models/siglip/configuration_siglip.py +66 -0
  28. optimum/rbln/transformers/models/siglip/modeling_siglip.py +146 -0
  29. optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -0
  30. optimum/rbln/transformers/utils/rbln_quantization.py +121 -72
  31. optimum/rbln/utils/submodule.py +13 -1
  32. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/METADATA +1 -1
  33. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/RECORD +35 -24
  34. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/WHEEL +0 -0
  35. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,78 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch.nn as nn
+from transformers import PreTrainedModel
+
+from ....utils import logging
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from ...models.decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+from .opt_architecture import OPTWrapper
+
+
+logger = logging.get_logger(__name__)
+
+
+class MLP(nn.Module):
+    def __init__(self, fc1, fc2, activation_fn):
+        super(MLP, self).__init__()
+        self.fc1 = fc1
+        self.fc2 = fc2
+        self.activation_fn = activation_fn
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.activation_fn(x)
+        x = self.fc2(x)
+        return x
+
+
+class RBLNOPTForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The OPT Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based OPTForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers OPTForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    """
+
+    _decoder_wrapper_cls = OPTWrapper
+    _use_rotary_emb = False
+
+    def modify_opt_decoder_layer(layer):
+        mlp = MLP(layer.fc1, layer.fc2, layer.activation_fn)
+        layer.mlp = mlp
+        del layer.fc1
+        del layer.fc2
+        del layer.activation_fn
+
+        return layer
+
+    @classmethod
+    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+        wrapper_cfg = {
+            "max_seq_len": rbln_config.max_seq_len,
+            "attn_impl": rbln_config.attn_impl,
+            "kvcache_partition_len": rbln_config.kvcache_partition_len,
+            "kvcache_block_size": rbln_config.kvcache_block_size,
+            "use_rotary_emb": cls._use_rotary_emb,
+            "use_attention_mask": rbln_config.use_attention_mask,
+        }
+
+        for i in range(len(model.model.decoder.layers)):
+            model.model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.model.decoder.layers[i])
+
+        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
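The docstring above describes RBLNOPTForCausalLM as the entry point for converting and running OPTForCausalLM checkpoints on RBLN devices. A minimal usage sketch, assuming the standard optimum-rbln from_pretrained(..., export=True) flow and rbln_* compile-time keywords; the model id and parameter values are illustrative, not taken from this release:

from transformers import AutoTokenizer

from optimum.rbln import RBLNOPTForCausalLM

# Convert and compile an OPT checkpoint for RBLN NPUs (illustrative settings).
model = RBLNOPTForCausalLM.from_pretrained(
    "facebook/opt-1.3b",
    export=True,            # trace the PyTorch model and compile it with the RBLN compiler
    rbln_max_seq_len=2048,  # assumed keyword, mirrors rbln_config.max_seq_len used above
    rbln_batch_size=1,
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
inputs = tokenizer("Rebellions NPUs are", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))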
@@ -0,0 +1,74 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+import torch.nn as nn
+
+from ...models.decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+)
+
+
+if TYPE_CHECKING:
+    from transformers import OPTForCausalLM
+
+
+class OPTWrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "OPTForCausalLM", max_seq_len: int):
+        if self.attn_impl != "eager":
+            raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
+
+        new_layers = []
+
+        for layer in causal_lm.model.decoder.layers:
+            new_self_attn = OPTAttention(
+                layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+            )
+            new_layer = OPTDecoderLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = OPTModel(causal_lm.model.decoder, new_layers, max_seq_len=max_seq_len, use_learned_pos_emb=True)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class OPTAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.q_proj = self._original_mod.q_proj
+        self.o_proj = self._original_mod.out_proj
+
+
+class OPTModel(DecoderOnlyModel):
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.embed_tokens
+
+    def get_pos_embedding(self):
+        return self._original_mod.embed_positions
+
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.final_layer_norm
+
+
+class OPTDecoderLayer(DecoderOnlyLayer):
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.self_attn_layer_norm
+
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.final_layer_norm
@@ -371,6 +371,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         query_length: int,
         use_inputs_embeds: bool,
         use_attention_mask: bool,
+        use_position_ids: bool,
         max_seq_len: int,
         kvcache_block_size: int,
         kvcache_num_blocks: int,
@@ -384,6 +385,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
             query_length,
             use_inputs_embeds,
             use_attention_mask,
+            use_position_ids,
             max_seq_len,
             kvcache_block_size,
             kvcache_num_blocks,
@@ -392,8 +394,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
             hidden_size,
             head_dim,
         )
-        pos_idx = 5 if query_length > 1 else 4
-        pos_idx = pos_idx if use_attention_mask else pos_idx - 1
+        pos_idx = 4 if query_length > 1 else 5
         input_info.insert(pos_idx, ("position_emb", [2, batch_size, 1, query_length, head_dim], "float32"))

         return input_info
@@ -562,7 +563,8 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         video_grid_thw: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         second_per_grid_ts: Optional[torch.Tensor] = None,
-        generate_idx: torch.Tensor = None,
+        generate_idx: Optional[torch.Tensor] = None,
+        return_dict: Optional[bool] = None,
         **kwargs,
     ) -> RBLNDecoderOnlyOutput:
         # Prefill
@@ -584,25 +586,29 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)

-                logit = self.prefill_decoder(
+                output = self.prefill_decoder(
                     inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
                     cache_position=cache_position,
                     batch_idx=b_idx,
                     position_embed=position_embed[:, b_idx : b_idx + 1],
                 )
-                logits.append(logit)
+                logits.append(output.logits)
             logits = torch.cat(logits, dim=0)
         # Decoder
         else:
             inputs_embeds, position_embed = self._preprocess_decoder(input_ids, cache_position)
-            logits = self.decoder(
+            output = self.decoder(
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_embed=position_embed,
             )
+            logits = output.logits

-        return RBLNDecoderOnlyOutput(
-            logits=logits,
-            generate_idx=generate_idx,
-        )
+        if not return_dict:
+            return logits, generate_idx
+        else:
+            return RBLNDecoderOnlyOutput(
+                logits=logits,
+                generate_idx=generate_idx,
+            )
@@ -157,58 +157,41 @@ class Qwen2_5_VLVisionWindowAttention(nn.Module):
 
 
 class Qwen2_5_VL_LanguageModelWrapper(DecoderOnlyWrapper):
-    def forward(self, *args):
-        if self.phase == "decode":
-            if self.use_attention_mask:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    attention_mask,
-                    block_tables,
-                    position_emb,
-                    *past_key_values,
-                ) = args
-            else:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    block_tables,
-                    position_emb,
-                    *past_key_values,
-                ) = args
-                attention_mask = None
-            query_position = None
-        elif self.phase == "prefill":
-            if self.use_attention_mask:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    attention_mask,
-                    query_position,
-                    block_tables,
-                    position_emb,
-                    *past_key_values,
-                ) = args
-            else:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    query_position,
-                    block_tables,
-                    position_emb,
-                    *past_key_values,
-                ) = args
-                attention_mask = None
-
-        else:
-            raise ValueError(f"Unknown phase: {self.phase}")
-
-        return self.forward_common(
-            input_ids_or_inputs_embeds,
+    def prepare_forward_args(self, *args):
+        args = list(args)
+        input_ids = None if self.use_inputs_embeds else args.pop(0)
+        inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
+        cache_position = args.pop(0)
+        block_tables = args.pop(0)
+        position_embeds = args.pop(0)
+        query_position = args.pop(0) if self.phase == "prefill" else None
+        position_ids = None
+        attention_mask = args.pop(0) if self.use_attention_mask else None
+        past_key_values = args
+
+        if len(past_key_values) != 2 * self.num_hidden_layers:
+            raise ValueError(
+                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
+            )
+
+        # [key, value] * n_layer -> ( (key, value) ) * n_layer
+        # cache shape : batch, n_heads, 1, max_seq_len, head_dim
+        _past_key_values = []
+        for i in range(self.config.num_hidden_layers):
+            key_states = past_key_values[i * 2]
+            value_states = past_key_values[i * 2 + 1]
+            past_key_value = [key_states, value_states]
+            _past_key_values.append(past_key_value)
+        past_key_values = _past_key_values
+
+        return (
+            input_ids,
+            inputs_embeds,
             cache_position,
-            attention_mask,
-            query_position,
             block_tables,
-            position_emb,
-            *past_key_values,
+            query_position,
+            attention_mask,
+            position_ids,
+            past_key_values,
+            position_embeds,
         )
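The comments inside prepare_forward_args note that the compiled graph hands the KV cache over as a flat list, [key_0, value_0, key_1, value_1, ...], which is then regrouped into one (key, value) pair per decoder layer. A standalone sketch of that regrouping, with a dummy layer count and zero tensors standing in for the real cache blocks:

import torch

num_hidden_layers = 2  # illustrative; the real value comes from the model config
# flat layout received by the wrapper: [key_0, value_0, key_1, value_1, ...]
flat_cache = [torch.zeros(1, 8, 1, 16, 64) for _ in range(2 * num_hidden_layers)]

# regroup into one [key_i, value_i] pair per decoder layer
per_layer_cache = [
    [flat_cache[2 * i], flat_cache[2 * i + 1]] for i in range(num_hidden_layers)
]
assert len(per_layer_cache) == num_hidden_layers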
@@ -476,6 +476,8 @@ class Seq2SeqSelfAttention(nn.Module):
         ]
         if attention_mask is not None:
             args.insert(3, attention_mask.unsqueeze(2))
+        else:
+            args.append(None)

         attn_output = self.attn_decode(*args)

@@ -0,0 +1,20 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_siglip import (
+    RBLNSiglipVisionModelConfig,
+)
+from .modeling_siglip import (
+    RBLNSiglipVisionModel,
+)
@@ -0,0 +1,66 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNSiglipVisionModelConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        image_size: Optional[int] = None,
+        interpolate_pos_encoding: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for image processing. Defaults to 1.
+            image_size (Optional[int]): The size of input images. Can be an integer for square images,
+                a tuple/list (height, width), or a dictionary with 'height' and 'width' keys.
+            interpolate_pos_encoding (Optional[bool]): Whether to interpolate the position encoding.
+            output_hidden_states: (Optional[bool]): Whether to return hidden states.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.image_size = image_size
+        self.interpolate_pos_encoding = interpolate_pos_encoding or False
+        self.output_hidden_states = output_hidden_states
+
+    @property
+    def image_width(self):
+        if isinstance(self.image_size, int):
+            return self.image_size
+        elif isinstance(self.image_size, (list, tuple)):
+            return self.image_size[1]
+        else:
+            return self.image_size["width"]
+
+    @property
+    def image_height(self):
+        if isinstance(self.image_size, int):
+            return self.image_size
+        elif isinstance(self.image_size, (list, tuple)):
+            return self.image_size[0]
+        else:
+            return self.image_size["height"]
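The docstring above allows image_size to be an int, a (height, width) pair, or a dict with 'height' and 'width' keys, and the two properties resolve whichever form was given. A short sketch of that behavior; it assumes RBLNSiglipVisionModelConfig is importable from the package top level (the optimum/rbln/__init__.py changes listed in this diff suggest a re-export), and the sizes are illustrative:

from optimum.rbln import RBLNSiglipVisionModelConfig

square = RBLNSiglipVisionModelConfig(image_size=384)                           # int -> square image
rect = RBLNSiglipVisionModelConfig(image_size=(336, 448))                      # (height, width)
named = RBLNSiglipVisionModelConfig(image_size={"height": 336, "width": 448})  # dict form

print(square.image_height, square.image_width)  # 384 384
print(rect.image_height, rect.image_width)      # 336 448
print(named.image_height, named.image_width)    # 336 448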
@@ -0,0 +1,146 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Optional, Tuple, Union
+
+import torch
+from transformers import SiglipVisionConfig, SiglipVisionModel
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.models.siglip.modeling_siglip import SiglipVisionModelOutput
+
+from ....configuration_utils import RBLNCompileConfig
+from ....modeling import RBLNModel
+from ....utils.logging import get_logger
+from .configuration_siglip import RBLNSiglipVisionModelConfig
+
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+    from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
+
+
+class _SiglipVisionModel(torch.nn.Module):
+    def __init__(self, model: SiglipVisionModel, interpolate_pos_encoding: bool, output_hidden_states: bool):
+        super().__init__()
+        self.vision_model = model.vision_model
+        self.interpolate_pos_encoding = interpolate_pos_encoding
+        self.output_hidden_states = output_hidden_states
+
+    def forward(self, inp):
+        enc_out = self.vision_model(
+            inp,
+            output_hidden_states=self.output_hidden_states,
+            return_dict=False,
+            interpolate_pos_encoding=self.interpolate_pos_encoding,
+        )
+        return tuple(x for x in enc_out if x is not None)
+
+
+class RBLNSiglipVisionModel(RBLNModel):
+    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig) -> torch.nn.Module:
+        wrapper_cfg = {
+            "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
+            "output_hidden_states": rbln_config.output_hidden_states,
+        }
+        return _SiglipVisionModel(model, **wrapper_cfg).eval()
+
+    @classmethod
+    def update_rbln_config_using_pipe(
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
+        return rbln_config
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: "SiglipVisionConfig" = None,
+        rbln_config: Optional[RBLNSiglipVisionModelConfig] = None,
+    ) -> RBLNSiglipVisionModelConfig:
+        if rbln_config.image_size is None:
+            rbln_config.image_size = getattr(model_config, "image_size", None)
+
+        if isinstance(rbln_config.image_size, int):
+            rbln_config.image_size = (rbln_config.image_size, rbln_config.image_size)
+        if rbln_config.image_size is None:
+            raise ValueError("`rbln_image_size` should be specified!")
+
+        if rbln_config.output_hidden_states is None:
+            rbln_config.output_hidden_states = model_config.output_hidden_states
+
+        rbln_compile_config = RBLNCompileConfig(
+            input_info=[
+                (
+                    "pixel_values",
+                    [
+                        rbln_config.batch_size,
+                        3,
+                        rbln_config.image_height,
+                        rbln_config.image_width,
+                    ],
+                    "float32",
+                )
+            ]
+        )
+
+        rbln_config.set_compile_cfgs([rbln_compile_config])
+        return rbln_config
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        return_dict: bool = None,
+        interpolate_pos_encoding: bool = False,
+        **kwargs,
+    ) -> Union[Tuple, SiglipVisionModelOutput]:
+        if len(kwargs) > 0 and any(kwargs.values()):
+            logger.warning(f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__}.")
+
+        if interpolate_pos_encoding != self.rbln_config.interpolate_pos_encoding:
+            raise ValueError(
+                f"Variable interpolate_pos_encoding {interpolate_pos_encoding} is not equal to rbln_config.interpolate_pos_encoding {self.rbln_config.interpolate_pos_encoding}"
+                f"Please compile again with the correct argument."
+            )
+        output = super().forward(pixel_values, return_dict=return_dict)
+        return output
+
+    def _prepare_output(self, output, return_dict):
+        """
+        Prepare model output based on return_dict flag.
+        This method can be overridden by subclasses to provide task-specific output handling.
+        """
+        if not return_dict:
+            return (output,) if not isinstance(output, (tuple, list)) else output
+        else:
+            last_hidden_state = (
+                output[0]
+                if self.rbln_config.interpolate_pos_encoding or self.rbln_config.output_hidden_states
+                else output
+            )
+            pooler_output = output[1] if self.rbln_config.interpolate_pos_encoding else None
+            if self.rbln_config.output_hidden_states:
+                hidden_states = (output[2:] if self.rbln_config.interpolate_pos_encoding else output[1:],)
+            else:
+                hidden_states = None
+
+            return BaseModelOutputWithPooling(
+                last_hidden_state=last_hidden_state,
+                pooler_output=pooler_output,
+                hidden_states=hidden_states,
+            )
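RBLNSiglipVisionModel compiles the SigLIP vision tower for a fixed pixel_values shape and, per _prepare_output above, repackages the raw outputs as either a tuple or a BaseModelOutputWithPooling depending on return_dict. A minimal sketch of compiling and calling it, assuming the usual export=True flow; the checkpoint id, rbln_image_size value, and dummy input are illustrative (a real pipeline would feed pixels from the checkpoint's image processor):

import torch

from optimum.rbln import RBLNSiglipVisionModel

# Compile the vision tower for 224x224 inputs (illustrative settings).
model = RBLNSiglipVisionModel.from_pretrained(
    "google/siglip-base-patch16-224",
    export=True,
    rbln_image_size=224,  # keyword name taken from the `rbln_image_size` error message above
)

# Dummy batch with the compiled shape: (batch_size, 3, image_height, image_width).
pixel_values = torch.zeros(1, 3, 224, 224)
outputs = model(pixel_values, return_dict=True)  # BaseModelOutputWithPooling when return_dict=True
print(type(outputs))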
@@ -313,6 +313,7 @@ class WhisperSelfAttention(WhisperAttention):
             args["mask"] = attention_mask.unsqueeze(2)
             attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(**args)
         else:
+            args["mask"] = None
             attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(**args)

         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)