optimum-rbln 0.7.5a0__py3-none-any.whl → 0.7.5a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +20 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +9 -4
- optimum/rbln/modeling.py +7 -5
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +22 -3
- optimum/rbln/transformers/models/__init__.py +23 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +93 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +298 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +12 -6
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +81 -77
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +160 -88
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +11 -7
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +19 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +78 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +16 -10
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +35 -52
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -0
- optimum/rbln/transformers/models/siglip/__init__.py +20 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +66 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +146 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +121 -72
- optimum/rbln/utils/submodule.py +13 -1
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/RECORD +35 -24
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,78 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import torch.nn as nn
|
16
|
+
from transformers import PreTrainedModel
|
17
|
+
|
18
|
+
from ....utils import logging
|
19
|
+
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
20
|
+
from ...models.decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
|
21
|
+
from .opt_architecture import OPTWrapper
|
22
|
+
|
23
|
+
|
24
|
+
logger = logging.get_logger(__name__)
|
25
|
+
|
26
|
+
|
27
|
+
class MLP(nn.Module):
    """Group OPT's feed-forward sub-layers into one module.

    The forward pass computes ``fc2(activation_fn(fc1(x)))``. The three
    submodules passed in are stored as-is (the same objects, not copies).
    """

    def __init__(self, fc1, fc2, activation_fn):
        super().__init__()
        self.fc1 = fc1
        self.fc2 = fc2
        self.activation_fn = activation_fn

    def forward(self, x):
        # Project up, apply the non-linearity, project back down.
        return self.fc2(self.activation_fn(self.fc1(x)))
|
39
|
+
|
40
|
+
|
41
|
+
class RBLNOPTForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    """
    The OPT Model transformer with a language modeling head (linear layer) on top.
    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

    A class to convert and run pre-trained transformers based OPTForCausalLM model on RBLN devices.
    It implements the methods to convert a pre-trained transformers OPTForCausalLM model into a RBLN transformer model by:
    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.
    """

    _decoder_wrapper_cls = OPTWrapper
    # OPT has no rotary embeddings; positions come from a learned embedding table.
    _use_rotary_emb = False

    @staticmethod
    def modify_opt_decoder_layer(layer):
        """Regroup a decoder layer's ``fc1``/``activation_fn``/``fc2`` into ``layer.mlp``.

        The layer is modified in place: the new ``MLP`` is attached as
        ``layer.mlp`` and the three original attributes are deleted. The same
        layer object is returned. Calling this twice on one layer fails,
        because ``fc1`` no longer exists after the first call.

        Declared as a ``staticmethod`` so it works whether it is reached
        through the class (``cls.modify_opt_decoder_layer``) or an instance.
        """
        layer.mlp = MLP(layer.fc1, layer.fc2, layer.activation_fn)
        del layer.fc1
        del layer.fc2
        del layer.activation_fn
        return layer

    @classmethod
    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
        """Prepare ``model`` for compilation and wrap it in ``_decoder_wrapper_cls``.

        Every layer in ``model.model.decoder.layers`` is rewritten in place via
        ``modify_opt_decoder_layer``, so ``model`` itself is mutated. The
        wrapper is built from the listed ``rbln_config`` fields and returned in
        eval mode.
        """
        wrapper_cfg = {
            "max_seq_len": rbln_config.max_seq_len,
            "attn_impl": rbln_config.attn_impl,
            "kvcache_partition_len": rbln_config.kvcache_partition_len,
            "kvcache_block_size": rbln_config.kvcache_block_size,
            "use_rotary_emb": cls._use_rotary_emb,
            "use_attention_mask": rbln_config.use_attention_mask,
        }

        layers = model.model.decoder.layers
        for idx, layer in enumerate(layers):
            layers[idx] = cls.modify_opt_decoder_layer(layer)

        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
|
@@ -0,0 +1,74 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
import torch.nn as nn
|
18
|
+
|
19
|
+
from ...models.decoderonly.decoderonly_architecture import (
|
20
|
+
DecoderOnlyAttention,
|
21
|
+
DecoderOnlyForCausalLM,
|
22
|
+
DecoderOnlyLayer,
|
23
|
+
DecoderOnlyModel,
|
24
|
+
DecoderOnlyWrapper,
|
25
|
+
)
|
26
|
+
|
27
|
+
|
28
|
+
if TYPE_CHECKING:
|
29
|
+
from transformers import OPTForCausalLM
|
30
|
+
|
31
|
+
|
32
|
+
class OPTWrapper(DecoderOnlyWrapper):
    """Decoder-only wrapper that rebuilds an OPT causal LM from the RBLN decoder-only blocks."""

    def convert_to_rbln_causal_lm(self, causal_lm: "OPTForCausalLM", max_seq_len: int):
        """Wrap each decoder layer of ``causal_lm`` and assemble a ``DecoderOnlyForCausalLM``.

        Only ``attn_impl == "eager"`` is supported; anything else raises
        ``NotImplementedError``. The model is built with learned position
        embeddings (``use_learned_pos_emb=True``).
        """
        if self.attn_impl != "eager":
            raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")

        decoder = causal_lm.model.decoder
        rbln_layers = [
            OPTDecoderLayer(
                layer,
                OPTAttention(layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size),
            )
            for layer in decoder.layers
        ]
        rbln_model = OPTModel(decoder, rbln_layers, max_seq_len=max_seq_len, use_learned_pos_emb=True)
        return DecoderOnlyForCausalLM(causal_lm, rbln_model)
|
48
|
+
|
49
|
+
|
50
|
+
class OPTAttention(DecoderOnlyAttention):
    """Attention adapter that re-exposes OPT's projection layers.

    Presumably these are the attribute names ``DecoderOnlyAttention`` reads
    (``q_proj``/``k_proj``/``v_proj``/``o_proj``) — TODO confirm against the base class.
    """

    def __post_init__(self):
        source = self._original_mod
        # Assignment order (k, v, q, o) is kept so submodule registration order is unchanged.
        self.k_proj = source.k_proj
        self.v_proj = source.v_proj
        self.q_proj = source.q_proj
        # OPT names its output projection ``out_proj``; alias it as ``o_proj``.
        self.o_proj = source.out_proj
|
56
|
+
|
57
|
+
|
58
|
+
class OPTModel(DecoderOnlyModel):
    """Model adapter exposing the embeddings and final norm of the wrapped OPT decoder."""

    def get_embedding(self) -> nn.Embedding:
        """Return the decoder's token embedding table (``embed_tokens``)."""
        decoder = self._original_mod
        return decoder.embed_tokens

    def get_pos_embedding(self):
        """Return the decoder's learned position embedding module (``embed_positions``)."""
        decoder = self._original_mod
        return decoder.embed_positions

    def get_last_layernorm(self) -> nn.LayerNorm:
        """Return the decoder-level ``final_layer_norm``.

        NOTE(review): some OPT configs may leave this as None — confirm the
        base class tolerates that.
        """
        decoder = self._original_mod
        return decoder.final_layer_norm
|
67
|
+
|
68
|
+
|
69
|
+
class OPTDecoderLayer(DecoderOnlyLayer):
    """Layer adapter mapping OPT's per-layer norms onto the decoder-only layer hooks."""

    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
        """Return the norm applied before self-attention (``self_attn_layer_norm``)."""
        layer = self._original_mod
        return layer.self_attn_layer_norm

    def get_post_attention_layernorm(self) -> nn.LayerNorm:
        """Return the layer's ``final_layer_norm``, used as the post-attention norm."""
        layer = self._original_mod
        return layer.final_layer_norm
|
@@ -371,6 +371,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
|
|
371
371
|
query_length: int,
|
372
372
|
use_inputs_embeds: bool,
|
373
373
|
use_attention_mask: bool,
|
374
|
+
use_position_ids: bool,
|
374
375
|
max_seq_len: int,
|
375
376
|
kvcache_block_size: int,
|
376
377
|
kvcache_num_blocks: int,
|
@@ -384,6 +385,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
|
|
384
385
|
query_length,
|
385
386
|
use_inputs_embeds,
|
386
387
|
use_attention_mask,
|
388
|
+
use_position_ids,
|
387
389
|
max_seq_len,
|
388
390
|
kvcache_block_size,
|
389
391
|
kvcache_num_blocks,
|
@@ -392,8 +394,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
|
|
392
394
|
hidden_size,
|
393
395
|
head_dim,
|
394
396
|
)
|
395
|
-
pos_idx =
|
396
|
-
pos_idx = pos_idx if use_attention_mask else pos_idx - 1
|
397
|
+
pos_idx = 4 if query_length > 1 else 5
|
397
398
|
input_info.insert(pos_idx, ("position_emb", [2, batch_size, 1, query_length, head_dim], "float32"))
|
398
399
|
|
399
400
|
return input_info
|
@@ -562,7 +563,8 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
|
|
562
563
|
video_grid_thw: Optional[torch.LongTensor] = None,
|
563
564
|
cache_position: Optional[torch.LongTensor] = None,
|
564
565
|
second_per_grid_ts: Optional[torch.Tensor] = None,
|
565
|
-
generate_idx: torch.Tensor = None,
|
566
|
+
generate_idx: Optional[torch.Tensor] = None,
|
567
|
+
return_dict: Optional[bool] = None,
|
566
568
|
**kwargs,
|
567
569
|
) -> RBLNDecoderOnlyOutput:
|
568
570
|
# Prefill
|
@@ -584,25 +586,29 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
|
|
584
586
|
for b_idx in range(batch_size):
|
585
587
|
cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
|
586
588
|
|
587
|
-
|
589
|
+
output = self.prefill_decoder(
|
588
590
|
inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
|
589
591
|
attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
|
590
592
|
cache_position=cache_position,
|
591
593
|
batch_idx=b_idx,
|
592
594
|
position_embed=position_embed[:, b_idx : b_idx + 1],
|
593
595
|
)
|
594
|
-
logits.append(
|
596
|
+
logits.append(output.logits)
|
595
597
|
logits = torch.cat(logits, dim=0)
|
596
598
|
# Decoder
|
597
599
|
else:
|
598
600
|
inputs_embeds, position_embed = self._preprocess_decoder(input_ids, cache_position)
|
599
|
-
|
601
|
+
output = self.decoder(
|
600
602
|
inputs_embeds=inputs_embeds,
|
601
603
|
cache_position=cache_position,
|
602
604
|
position_embed=position_embed,
|
603
605
|
)
|
606
|
+
logits = output.logits
|
604
607
|
|
605
|
-
|
606
|
-
logits
|
607
|
-
|
608
|
-
|
608
|
+
if not return_dict:
|
609
|
+
return logits, generate_idx
|
610
|
+
else:
|
611
|
+
return RBLNDecoderOnlyOutput(
|
612
|
+
logits=logits,
|
613
|
+
generate_idx=generate_idx,
|
614
|
+
)
|
@@ -157,58 +157,41 @@ class Qwen2_5_VLVisionWindowAttention(nn.Module):
|
|
157
157
|
|
158
158
|
|
159
159
|
class Qwen2_5_VL_LanguageModelWrapper(DecoderOnlyWrapper):
|
160
|
-
def
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
*past_key_values,
|
191
|
-
) = args
|
192
|
-
else:
|
193
|
-
(
|
194
|
-
input_ids_or_inputs_embeds,
|
195
|
-
cache_position,
|
196
|
-
query_position,
|
197
|
-
block_tables,
|
198
|
-
position_emb,
|
199
|
-
*past_key_values,
|
200
|
-
) = args
|
201
|
-
attention_mask = None
|
202
|
-
|
203
|
-
else:
|
204
|
-
raise ValueError(f"Unknown phase: {self.phase}")
|
205
|
-
|
206
|
-
return self.forward_common(
|
207
|
-
input_ids_or_inputs_embeds,
|
160
|
+
def prepare_forward_args(self, *args):
|
161
|
+
args = list(args)
|
162
|
+
input_ids = None if self.use_inputs_embeds else args.pop(0)
|
163
|
+
inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
|
164
|
+
cache_position = args.pop(0)
|
165
|
+
block_tables = args.pop(0)
|
166
|
+
position_embeds = args.pop(0)
|
167
|
+
query_position = args.pop(0) if self.phase == "prefill" else None
|
168
|
+
position_ids = None
|
169
|
+
attention_mask = args.pop(0) if self.use_attention_mask else None
|
170
|
+
past_key_values = args
|
171
|
+
|
172
|
+
if len(past_key_values) != 2 * self.num_hidden_layers:
|
173
|
+
raise ValueError(
|
174
|
+
f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
|
175
|
+
)
|
176
|
+
|
177
|
+
# [key, value] * n_layer -> ( (key, value) ) * n_layer
|
178
|
+
# cache shape : batch, n_heads, 1, max_seq_len, head_dim
|
179
|
+
_past_key_values = []
|
180
|
+
for i in range(self.config.num_hidden_layers):
|
181
|
+
key_states = past_key_values[i * 2]
|
182
|
+
value_states = past_key_values[i * 2 + 1]
|
183
|
+
past_key_value = [key_states, value_states]
|
184
|
+
_past_key_values.append(past_key_value)
|
185
|
+
past_key_values = _past_key_values
|
186
|
+
|
187
|
+
return (
|
188
|
+
input_ids,
|
189
|
+
inputs_embeds,
|
208
190
|
cache_position,
|
209
|
-
attention_mask,
|
210
|
-
query_position,
|
211
191
|
block_tables,
|
212
|
-
|
213
|
-
|
192
|
+
query_position,
|
193
|
+
attention_mask,
|
194
|
+
position_ids,
|
195
|
+
past_key_values,
|
196
|
+
position_embeds,
|
214
197
|
)
|
@@ -0,0 +1,20 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from .configuration_siglip import (
|
16
|
+
RBLNSiglipVisionModelConfig,
|
17
|
+
)
|
18
|
+
from .modeling_siglip import (
|
19
|
+
RBLNSiglipVisionModel,
|
20
|
+
)
|
@@ -0,0 +1,66 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, Optional, Tuple, Union
|
16
|
+
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
18
|
+
|
19
|
+
|
20
|
+
class RBLNSiglipVisionModelConfig(RBLNModelConfig):
    def __init__(
        self,
        batch_size: Optional[int] = None,
        image_size: Optional[Union[int, Tuple[int, int], Dict[str, int]]] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ):
        """
        Args:
            batch_size (Optional[int]): The batch size for image processing. Defaults to 1.
            image_size (Optional[Union[int, Tuple[int, int], Dict[str, int]]]): The size of input images.
                Can be an integer for square images, a tuple/list (height, width), or a dictionary
                with 'height' and 'width' keys.
            interpolate_pos_encoding (Optional[bool]): Whether to interpolate the position encoding.
                Defaults to False.
            output_hidden_states (Optional[bool]): Whether to return hidden states.
            **kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        super().__init__(**kwargs)
        self.batch_size = batch_size or 1
        # Check "<= 0" to match the "positive integer" contract in the error message.
        if not isinstance(self.batch_size, int) or self.batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        self.image_size = image_size
        self.interpolate_pos_encoding = interpolate_pos_encoding or False
        self.output_hidden_states = output_hidden_states

    def _image_dim(self, index: int, key: str):
        """Resolve one spatial dimension from ``image_size``.

        ``index`` selects from a (height, width) sequence and ``key`` selects
        from a dict. An int ``image_size`` is returned for either dimension.
        """
        if isinstance(self.image_size, int):
            return self.image_size
        if isinstance(self.image_size, (list, tuple)):
            return self.image_size[index]
        return self.image_size[key]

    @property
    def image_width(self):
        """Input image width derived from ``image_size``."""
        return self._image_dim(1, "width")

    @property
    def image_height(self):
        """Input image height derived from ``image_size``."""
        return self._image_dim(0, "height")
|
@@ -0,0 +1,146 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from transformers import SiglipVisionConfig, SiglipVisionModel
|
19
|
+
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
20
|
+
from transformers.models.siglip.modeling_siglip import SiglipVisionModelOutput
|
21
|
+
|
22
|
+
from ....configuration_utils import RBLNCompileConfig
|
23
|
+
from ....modeling import RBLNModel
|
24
|
+
from ....utils.logging import get_logger
|
25
|
+
from .configuration_siglip import RBLNSiglipVisionModelConfig
|
26
|
+
|
27
|
+
|
28
|
+
logger = get_logger(__name__)
|
29
|
+
|
30
|
+
if TYPE_CHECKING:
|
31
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
|
32
|
+
|
33
|
+
from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
34
|
+
|
35
|
+
|
36
|
+
class _SiglipVisionModel(torch.nn.Module):
    """Compilation wrapper around ``model.vision_model`` with call options fixed at construction.

    ``forward`` always requests a tuple (``return_dict=False``) and returns
    only the entries that are not None.
    """

    def __init__(self, model: SiglipVisionModel, interpolate_pos_encoding: bool, output_hidden_states: bool):
        super().__init__()
        self.vision_model = model.vision_model
        self.interpolate_pos_encoding = interpolate_pos_encoding
        self.output_hidden_states = output_hidden_states

    def forward(self, inp):
        outputs = self.vision_model(
            inp,
            output_hidden_states=self.output_hidden_states,
            return_dict=False,
            interpolate_pos_encoding=self.interpolate_pos_encoding,
        )
        # Drop entries the vision model returned as None so only real outputs are traced.
        return tuple(filter(lambda item: item is not None, outputs))
|
51
|
+
|
52
|
+
|
53
|
+
class RBLNSiglipVisionModel(RBLNModel):
    """RBLN-compiled wrapper for a transformers ``SiglipVisionModel``.

    The compiled graph takes a fixed-shape ``pixel_values`` input and returns
    the non-None outputs of the vision transformer, in this order:
    last hidden state, pooled output (only when the model has a pooling
    head), then the hidden states (only when ``output_hidden_states``).
    """

    @classmethod
    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig) -> torch.nn.Module:
        """Wrap ``model`` so its call options are baked in from ``rbln_config``."""
        wrapper_cfg = {
            "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
            "output_hidden_states": rbln_config.output_hidden_states,
        }
        return _SiglipVisionModel(model, **wrapper_cfg).eval()

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """No pipeline-specific adjustments are needed; return ``rbln_config`` unchanged."""
        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: Optional["PreTrainedModel"] = None,
        model_config: "SiglipVisionConfig" = None,
        rbln_config: Optional[RBLNSiglipVisionModelConfig] = None,
    ) -> RBLNSiglipVisionModelConfig:
        """Fill unset fields of ``rbln_config`` from ``model_config`` and register the compile config.

        Raises:
            ValueError: If no image size is given and ``model_config`` has none.
        """
        if rbln_config.image_size is None:
            rbln_config.image_size = getattr(model_config, "image_size", None)

        if isinstance(rbln_config.image_size, int):
            rbln_config.image_size = (rbln_config.image_size, rbln_config.image_size)
        if rbln_config.image_size is None:
            raise ValueError("`rbln_image_size` should be specified!")

        if rbln_config.output_hidden_states is None:
            rbln_config.output_hidden_states = model_config.output_hidden_states

        # Use the model's channel count instead of assuming RGB; fall back to 3 if unset.
        num_channels = getattr(model_config, "num_channels", 3)

        rbln_compile_config = RBLNCompileConfig(
            input_info=[
                (
                    "pixel_values",
                    [
                        rbln_config.batch_size,
                        num_channels,
                        rbln_config.image_height,
                        rbln_config.image_width,
                    ],
                    "float32",
                )
            ]
        )

        rbln_config.set_compile_cfgs([rbln_compile_config])
        return rbln_config

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        **kwargs,
    ) -> Union[Tuple, SiglipVisionModelOutput]:
        """Run the compiled vision model.

        Raises:
            ValueError: If ``interpolate_pos_encoding`` differs from the value the
                model was compiled with (it is baked into the graph).
        """
        if len(kwargs) > 0 and any(kwargs.values()):
            logger.warning(f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__}.")

        if interpolate_pos_encoding != self.rbln_config.interpolate_pos_encoding:
            raise ValueError(
                f"Variable interpolate_pos_encoding {interpolate_pos_encoding} is not equal to rbln_config.interpolate_pos_encoding {self.rbln_config.interpolate_pos_encoding}. "
                "Please compile again with the correct argument."
            )
        output = super().forward(pixel_values, return_dict=return_dict)
        return output

    def _has_pooler_output(self) -> bool:
        """Whether the compiled graph emits a pooled output.

        Mirrors transformers' SigLIP vision transformer, which adds a pooling
        head unless the config sets ``vision_use_head`` to False (default True).
        """
        config = getattr(self, "config", None)
        return bool(getattr(config, "vision_use_head", True))

    def _prepare_output(self, output, return_dict):
        """
        Prepare model output based on return_dict flag.
        This method can be overridden by subclasses to provide task-specific output handling.

        The runtime output layout depends on whether the model has a pooling
        head and on ``output_hidden_states``; ``interpolate_pos_encoding``
        does not change how many tensors are returned.
        """
        if not return_dict:
            return (output,) if not isinstance(output, (tuple, list)) else output

        outputs = list(output) if isinstance(output, (tuple, list)) else [output]
        last_hidden_state = outputs[0]
        remaining = outputs[1:]

        pooler_output = None
        if self._has_pooler_output() and remaining:
            pooler_output = remaining[0]
            remaining = remaining[1:]

        # Hidden states are returned flat by the runtime; expose them as one tuple of tensors.
        hidden_states = tuple(remaining) if self.rbln_config.output_hidden_states else None

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
            hidden_states=hidden_states,
        )
|
@@ -313,6 +313,7 @@ class WhisperSelfAttention(WhisperAttention):
|
|
313
313
|
args["mask"] = attention_mask.unsqueeze(2)
|
314
314
|
attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(**args)
|
315
315
|
else:
|
316
|
+
args["mask"] = None
|
316
317
|
attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(**args)
|
317
318
|
|
318
319
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|