optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +108 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +156 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +48 -21
- optimum/rbln/modeling_base.py +99 -22
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/runtime_utils.py +60 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/pixtral/__init__.py:

```diff
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_pixtral import RBLNPixtralVisionModelConfig
+from .modeling_pixtral import RBLNPixtralVisionModel
```
optimum/rbln/transformers/models/pixtral/configuration_pixtral.py:

```diff
@@ -0,0 +1,43 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Tuple
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNPixtralVisionModelConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        max_image_size: Tuple = None,
+        batch_size: Optional[int] = None,
+        output_hidden_states: Optional[bool] = None,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            max_image_size (Tuple): The size of max input images. A tuple (max_height, max_width)
+            batch_size (Optional[int]): The batch size for image processing. Defaults to 1.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.max_image_size = max_image_size
+        self.output_hidden_states = output_hidden_states
```
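
Since `max_image_size` fixes the static shape the vision encoder is compiled with, the patch budget it implies is worth spelling out. A small illustrative sketch (the 1024-pixel size and 16-pixel patch are assumed example values, not defaults stated in this diff):

```python
# Illustrative sketch: how max_image_size translates into the number of patch
# positions reserved in the compiled graph. Values are assumptions for the example.
from optimum.rbln.transformers.models.pixtral import RBLNPixtralVisionModelConfig

cfg = RBLNPixtralVisionModelConfig(max_image_size=(1024, 1024), batch_size=1)

patch_size = 16  # assumed PixtralVisionConfig.patch_size
num_total_patches = (cfg.max_image_size[0] // patch_size) * (cfg.max_image_size[1] // patch_size)
print(num_total_patches)  # 4096 patch slots; smaller images are padded up to this
```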
optimum/rbln/transformers/models/pixtral/modeling_pixtral.py:

```diff
@@ -0,0 +1,322 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
+
+import rebel
+import torch
+import torch.nn as nn
+from transformers import PixtralVisionConfig, PixtralVisionModel
+from transformers.modeling_outputs import BaseModelOutput
+from transformers.modeling_utils import no_init_weights
+from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm, PixtralRotaryEmbedding
+
+from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+from ....modeling import RBLNModel
+from ....utils.logging import get_logger
+from ....utils.runtime_utils import RBLNPytorchRuntime
+from .configuration_pixtral import RBLNPixtralVisionModelConfig
+from .pixtral_architecture import PixtralAttention
+
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+
+class RBLNRuntimePixtralVisionModel(RBLNPytorchRuntime):
+    mandatory_members = ["main_input_name"]
+
+    def __init__(
+        self,
+        runtime: rebel.Runtime,
+        config: PixtralVisionConfig,
+        rbln_config: RBLNPixtralVisionModelConfig,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(runtime, **kwargs)
+        self.patch_positional_embedding = PixtralRotaryEmbedding(config)
+        self.patch_size = config.patch_size
+        self.image_size = config.image_size
+        self.hidden_size = config.hidden_size
+        self.max_image_size = rbln_config.max_image_size
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        image_sizes: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ):
+        if pixel_values.shape[2] > self.max_image_size[0] or pixel_values.shape[3] > self.max_image_size[1]:
+            raise ValueError("The height() and width of pixel_values can't be larger than max_image_size.")
+
+        if pixel_values.shape[2] != self.max_image_size[0] or pixel_values.shape[3] != self.max_image_size[1]:
+            padded_pixel_values = [
+                torch.nn.functional.pad(
+                    image,
+                    pad=(
+                        0,
+                        self.max_image_size[1] - pixel_values.shape[3],
+                        0,
+                        self.max_image_size[0] - pixel_values.shape[2],
+                    ),
+                )
+                for image in pixel_values
+            ]
+            pixel_values = torch.stack(padded_pixel_values)
+
+        batch_size, _, H_max, W_max = pixel_values.shape
+        H_max_p = H_max // self.patch_size
+        W_max_p = W_max // self.patch_size
+
+        final_hidden_states = None
+
+        last_hidden_state_list = []
+        if output_hidden_states:
+            batch_hidden_states_list = []
+
+        for i in range(batch_size):
+            h_patched_original = image_sizes[i, 0] // self.patch_size
+            w_patched_original = image_sizes[i, 1] // self.patch_size
+
+            single_pixel_values = pixel_values[i : i + 1]
+            patch_embed = self.patch_conv(single_pixel_values)
+            patch_embed_seq = patch_embed[:, :, :h_patched_original, :w_patched_original].flatten(2).transpose(1, 2)
+            patch_embed_seq = self.ln_pre(patch_embed_seq)
+            patch_embed_seq = nn.functional.pad(
+                patch_embed_seq, (0, 0, 0, H_max_p * W_max_p - patch_embed_seq.shape[1]), "constant", value=0
+            )
+
+            max_w_from_config = self.image_size // self.patch_size
+            mesh = torch.meshgrid(torch.arange(h_patched_original), torch.arange(w_patched_original), indexing="ij")
+            h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1)
+            ids = h_grid * max_w_from_config + v_grid
+            position_ids = ids[:, 0]
+
+            position_embeddings = self.patch_positional_embedding(patch_embed_seq, position_ids)
+            cos = nn.functional.pad(
+                position_embeddings[0],
+                (0, 0, 0, H_max_p * W_max_p - position_embeddings[0].shape[0]),
+                "constant",
+                value=0,
+            )
+            sin = nn.functional.pad(
+                position_embeddings[1],
+                (0, 0, 0, H_max_p * W_max_p - position_embeddings[1].shape[0]),
+                "constant",
+                value=0,
+            )
+
+            attention_mask = torch.full(
+                (1, patch_embed_seq.shape[-2]), fill_value=torch.finfo(patch_embed_seq.dtype).min
+            )
+            attention_mask[:, : h_patched_original * w_patched_original] = 0
+            if "out" in kwargs:
+                super().forward(patch_embed_seq, attention_mask, cos, sin, **kwargs)
+                transformer_output = kwargs["out"]
+            else:
+                transformer_output = super().forward(patch_embed_seq, attention_mask, cos, sin, **kwargs)
+
+            last_hidden_state_list.append(transformer_output[0][:, : h_patched_original * w_patched_original, :])
+            hidden_states = transformer_output[1:]
+
+            if output_hidden_states:
+                batch_hidden_states_list.append(
+                    [hidden_state[:, : h_patched_original * w_patched_original, :] for hidden_state in hidden_states]
+                )
+
+        final_last_hidden_state = torch.cat(last_hidden_state_list, dim=1)
+
+        if output_hidden_states:
+            hidden_states = [
+                torch.cat(
+                    [batch_hidden_states[layer_idx] for batch_hidden_states in batch_hidden_states_list],
+                    dim=1,
+                )
+                for layer_idx in range(len(batch_hidden_states_list[0]))
+            ]
+
+            final_hidden_states = tuple(hidden_states)
+
+        if not return_dict:
+            return tuple(v for v in (final_last_hidden_state, final_hidden_states) if v is not None)
+
+        # TODO: output_attentions
+        return BaseModelOutput(
+            last_hidden_state=final_last_hidden_state,
+            hidden_states=final_hidden_states,
+        )
+
+
+class _PixtralVisionModel(torch.nn.Module):
+    def __init__(self, model: PixtralVisionModel, output_hidden_states: bool):
+        super().__init__()
+        self.transformer = self.convert_to_rbln_pixtral_vision_model(model)
+        self.output_hidden_states = output_hidden_states
+
+    def convert_to_rbln_pixtral_vision_model(self, model: nn.Module):
+        for layer in model.transformer.layers:
+            layer.attention = PixtralAttention(layer.attention)
+        return model.transformer
+
+    def forward(self, patch_embeds, attention_mask, position_embeddings_1, position_embeddings_2):
+        output = self.transformer(
+            inputs_embeds=patch_embeds,
+            attention_mask=attention_mask,
+            position_embeddings=(position_embeddings_1, position_embeddings_2),
+            output_hidden_states=self.output_hidden_states,
+            return_dict=False,
+        )
+        return output
+
+
+class RBLNPixtralVisionModel(RBLNModel):
+    """
+    RBLN optimized Pixtral vision encoder model.
+
+    This class provides hardware-accelerated inference for Pixtral vision encoders
+    on RBLN devices, supporting image encoding for multimodal tasks.
+    """
+
+    def __post_init__(self, **kwargs):
+        artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+        with no_init_weights():
+            self.patch_conv = nn.Conv2d(
+                in_channels=self.config.num_channels,
+                out_channels=self.config.hidden_size,
+                kernel_size=self.config.patch_size,
+                stride=self.config.patch_size,
+                bias=False,
+            )
+            self.ln_pre = PixtralRMSNorm(self.config.hidden_size, eps=1e-5)
+        self.patch_conv.load_state_dict(artifacts["patch_conv"])
+        self.ln_pre.load_state_dict(artifacts["ln_pre"])
+        self.model = RBLNRuntimePixtralVisionModel(
+            self.model[0],
+            main_input_name="pixel_values",
+            config=self.config,
+            rbln_config=self.rbln_config,
+            patch_conv=self.patch_conv,
+            ln_pre=self.ln_pre,
+        )
+
+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: "PreTrainedModel",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNModelConfig,
+    ):
+        save_dict = {}
+        save_dict["patch_conv"] = model.get_input_embeddings().state_dict()
+        save_dict["ln_pre"] = model.ln_pre.state_dict()
+        torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+    @classmethod
+    def _wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: RBLNPixtralVisionModelConfig
+    ) -> torch.nn.Module:
+        wrapper_cfg = {
+            "output_hidden_states": rbln_config.output_hidden_states,
+        }
+        return _PixtralVisionModel(model, **wrapper_cfg).eval()
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: "PixtralVisionConfig" = None,
+        rbln_config: Optional[RBLNPixtralVisionModelConfig] = None,
+    ) -> RBLNPixtralVisionModelConfig:
+        if rbln_config.max_image_size is None:
+            rbln_config.max_image_size = (model_config.image_size, model_config.image_size)
+
+        if rbln_config.output_hidden_states is None:
+            rbln_config.output_hidden_states = getattr(model_config, "output_hidden_states", False)
+
+        num_total_patches = (rbln_config.max_image_size[0] // model_config.patch_size) * (
+            rbln_config.max_image_size[1] // model_config.patch_size
+        )
+
+        rbln_compile_config = RBLNCompileConfig(
+            input_info=[
+                (
+                    "patch_embeds",
+                    [1, num_total_patches, model_config.hidden_size],
+                    "float32",
+                ),
+                ("attention_mask", [1, num_total_patches], "float32"),
+                (
+                    "position_embeddings_1",
+                    [
+                        num_total_patches,
+                        model_config.head_dim,
+                    ],
+                    "float32",
+                ),
+                (
+                    "position_embeddings_2",
+                    [
+                        num_total_patches,
+                        model_config.head_dim,
+                    ],
+                    "float32",
+                ),
+            ]
+        )
+
+        rbln_config.set_compile_cfgs([rbln_compile_config])
+        return rbln_config
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.FloatTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: bool = True,
+        **kwargs,
+    ) -> Union[Tuple, BaseModelOutput]:
+        """
+        Forward pass for the RBLN-optimized Pixtral vision model.
+
+        Args:
+            pixel_values (torch.Tensor of shape (batch_size, num_channels, image_size, image_size)) — The tensors corresponding to the input images. Pixel values can be obtained using PixtralImageProcessor. See PixtralImageProcessor.call() for details (PixtralProcessor uses PixtralImageProcessor for processing images).
+            image_sizes (torch.Tensor of shape (batch_size, 2), optional) — The sizes of the images in the batch, being (height, width) for each image.
+            output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+            return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            BaseModelOutput or tuple(torch.FloatTensor)
+        """
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
+        output = self.model(
+            pixel_values, image_sizes, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs
+        )
+
+        return output
```
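
A minimal usage sketch for the class above. The placeholder checkpoint id and the `export=True` / `rbln_*` keyword pattern follow the usual optimum-rbln `from_pretrained` convention and are assumptions here rather than something this diff documents:

```python
# Hypothetical sketch: compile the Pixtral vision tower for an RBLN NPU and run it
# on preprocessed images. Checkpoint id and rbln_* values are placeholders.
import torch

from optimum.rbln.transformers.models.pixtral import RBLNPixtralVisionModel

model = RBLNPixtralVisionModel.from_pretrained(
    "<pixtral-vision-checkpoint>",        # placeholder model id
    export=True,                          # compile from the PyTorch weights
    rbln_max_image_size=(1024, 1024),     # static upper bound on input resolution (assumed kwarg form)
)

pixel_values = torch.randn(1, 3, 512, 768)   # smaller than max_image_size, padded internally
image_sizes = torch.tensor([[512, 768]])     # real (height, width) of each image
outputs = model(pixel_values=pixel_values, image_sizes=image_sizes)
print(outputs.last_hidden_state.shape)       # (1, num_real_patches, hidden_size)
```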
optimum/rbln/transformers/models/pixtral/pixtral_architecture.py:

```diff
@@ -0,0 +1,73 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from ..decoderonly.decoderonly_architecture import apply_rotary_pos_emb
+
+
+class PixtralAttention(nn.Module):
+    def __init__(self, self_attention):
+        super().__init__()
+        self.original_model = self_attention
+        self.num_heads = getattr(self.original_model, "num_heads", None) or getattr(
+            self.original_model.config, "num_attention_heads"
+        )
+        self.head_dim = self.original_model.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        self.__post_init__()
+
+    def __post_init__(self):
+        self.q_proj = self.original_model.q_proj
+        self.k_proj = self.original_model.k_proj
+        self.v_proj = self.original_model.v_proj
+        self.o_proj = self.original_model.o_proj
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+    ):
+        batch_size, patches, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # TODO: return output attention
+        query_states = query_states.view(batch_size, patches, 1, self.num_heads, self.head_dim).transpose(1, 3)
+        key_states = key_states.view(batch_size, patches, 1, self.num_heads, self.head_dim).transpose(1, 3)
+        value_states = value_states.view(batch_size, patches, 1, self.num_heads, self.head_dim).transpose(1, 3)
+
+        cos, sin = position_embeddings
+        cos = cos[None, None, None, :, :]
+        sin = sin[None, None, None, :, :]
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(3, 4)) * self.scaling
+        attn_weights = attn_weights + attention_mask
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.transpose(1, 3)
+
+        attn_output = attn_output.reshape(batch_size, patches, -1)
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, _
```
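
The wrapper above swaps the stock Pixtral attention for an explicit matmul/softmax formulation that the RBLN compiler can trace. A standalone sanity check of that equivalence in plain PyTorch (no rotary embedding or masking, nothing imported from the package itself):

```python
# Minimal check that eager matmul/softmax attention matches PyTorch's fused SDPA.
# It mirrors the arithmetic in PixtralAttention.forward, minus rotary embeddings
# and the additive attention mask.
import torch

b, heads, seq, head_dim = 1, 4, 16, 32
q, k, v = (torch.randn(b, heads, seq, head_dim) for _ in range(3))

scaling = head_dim ** -0.5
weights = torch.softmax(torch.matmul(q, k.transpose(-2, -1)) * scaling, dim=-1)
eager_out = torch.matmul(weights, v)

sdpa_out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(eager_out, sdpa_out, atol=1e-5)
```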
optimum/rbln/transformers/models/qwen2/__init__.py:

```diff
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .configuration_qwen2 import RBLNQwen2ForCausalLMConfig
-from .modeling_qwen2 import RBLNQwen2ForCausalLM
+from .configuration_qwen2 import RBLNQwen2ForCausalLMConfig, RBLNQwen2ModelConfig
+from .modeling_qwen2 import RBLNQwen2ForCausalLM, RBLNQwen2Model
```
optimum/rbln/transformers/models/qwen2/configuration_qwen2.py:

```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
 
 
 class RBLNQwen2ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -40,3 +40,11 @@ class RBLNQwen2ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
     )
     ```
     """
+
+
+class RBLNQwen2ModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN Qwen2 models.
+
+    This class is an alias of RBLNDecoderOnlyModelConfig.
+    """
```
optimum/rbln/transformers/models/qwen2/modeling_qwen2.py:

```diff
@@ -15,7 +15,11 @@
 from transformers import PretrainedConfig
 
 from ....utils import logging
-from ...models.decoderonly import
+from ...models.decoderonly import (
+    RBLNDecoderOnlyModel,
+    RBLNDecoderOnlyModelForCausalLM,
+    RBLNDecoderOnlyModelForCausalLMConfig,
+)
 from .qwen2_architecture import QWEN2Wrapper
 
 
@@ -95,3 +99,25 @@ class RBLNQwen2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         rbln_config.sliding_window = model_config.sliding_window
         rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
         return rbln_config
+
+
+class RBLNQwen2Model(RBLNDecoderOnlyModel):
+    """
+    The Qwen2 Model transformer without a language modeling head.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    """
+
+    _decoder_wrapper_cls = QWEN2Wrapper
+
+    @classmethod
+    def _update_sliding_window_config(
+        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
+        # https://github.com/huggingface/transformers/issues/35896
+        # There seems to be a bug in transformers(v4.52.4). Therefore, similar to when attn_implementation is eager,
+        # we set all layers to use sliding window in this version. This should be updated once the bug is fixed.
+
+        rbln_config.cache_impl = "sliding_window"
+        rbln_config.sliding_window = model_config.sliding_window
+        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
+        return rbln_config
```
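
A short sketch of how the newly exported headless `RBLNQwen2Model` might be compiled. The checkpoint id and the `export=True` / `rbln_*` keywords follow the general optimum-rbln `from_pretrained` convention and are assumptions, not statements from this diff:

```python
# Hypothetical usage of the headless Qwen2 backbone added in 0.9.3.
from optimum.rbln.transformers.models.qwen2 import RBLNQwen2Model

model = RBLNQwen2Model.from_pretrained(
    "Qwen/Qwen2-1.5B-Instruct",   # assumed checkpoint id
    export=True,                  # compile the PyTorch weights for the RBLN NPU
    rbln_max_seq_len=4096,        # assumed compile-time sequence budget
)
model.save_pretrained("./qwen2-1.5b-rbln")   # reload later without recompiling
```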
optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py:

```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any,
+from typing import Any, List, Optional, Union
 
 from ....configuration_utils import RBLNModelConfig
 from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
@@ -31,10 +31,22 @@ class RBLNQwen2_5_VLForConditionalGenerationConfig(RBLNDecoderOnlyModelForCausal
 
     def __init__(
         self,
-        visual: Optional[RBLNModelConfig] = None,
         use_inputs_embeds: bool = True,
-
+        visual: Optional[RBLNModelConfig] = None,
+        **kwargs: Any,
     ):
+        """
+        Args:
+            use_inputs_embeds (bool): Whether or not to use `inputs_embeds` as input. Defaults to `True`.
+            visual (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+            kwargs: Additional arguments passed to the parent `RBLNDecoderOnlyModelForCausalLMConfig`.
+
+        Raises:
+            ValueError: If `use_inputs_embeds` is False.
+            ValueError: If the visual configuration is provided but contains invalid settings, such as an invalid max_seq_lens (e.g., not a positive integer, not a multiple of the window-based attention unit, or insufficient for the expected resolution).
+            ValueError: If visual is None and no default vision configuration can be inferred for the model architecture.
+            ValueError: If any inherited parameters violate constraints defined in the parent class, such as batch_size not being a positive integer, prefill_chunk_size not being divisible by 64, or max_seq_len not meeting requirements for Flash Attention.
+        """
         super().__init__(use_inputs_embeds=use_inputs_embeds, **kwargs)
         if not self.use_inputs_embeds:
             raise ValueError(
@@ -53,7 +65,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModelConfig(RBLNModelConfig):
     mechanisms for processing images and videos.
     """
 
-    def __init__(self, max_seq_lens: Union[int, List[int]] = None, **kwargs:
+    def __init__(self, max_seq_lens: Union[int, List[int]] = None, **kwargs: Any):
         """
         Args:
             max_seq_lens (Optional[Union[int, List[int]]]): Maximum sequence lengths for Vision
@@ -66,10 +78,13 @@ class RBLNQwen2_5_VisionTransformerPretrainedModelConfig(RBLNModelConfig):
                 making 256 (64 * 4) valid. RBLN optimization runs inference per image or video
                 frame, so set `max_seq_len` to match the maximum expected resolution to reduce
                 computation. If not provided, a `ValueError` is raised.
-
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
-            ValueError: If
+            ValueError: If `max_seq_lens` is None or not provided.
+            ValueError: If `max_seq_lens` (or any value in the list) is not a positive integer.
+            ValueError: If `max_seq_lens` is not a multiple of (window_size / patch_size)^2 for window-based attention, or is insufficient for the expected image/video resolution.
+            ValueError: If `batch_size` (inherited from RBLNModelConfig) is not a positive integer.
 
         Max Seq Lens:
             Since `Qwen2_5_VLForConditionalGeneration` performs inference on a per-image or per-frame basis,
```
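
Putting the two configuration classes together, one plausible construction that satisfies the constraints described in the docstrings above. The concrete numbers are illustrative only; 1024 and 4096 are multiples of the 64-token window unit the docstring mentions:

```python
# Illustrative configuration sketch for the Qwen2.5-VL classes documented above.
from optimum.rbln.transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
    RBLNQwen2_5_VLForConditionalGenerationConfig,
)

# Vision side: max_seq_lens must cover the largest expected image or frame and be
# a multiple of the window unit (64 for the default window_size / patch_size).
visual = RBLNQwen2_5_VisionTransformerPretrainedModelConfig(max_seq_lens=[1024, 4096])

# Language side: inputs_embeds is mandatory for this model, hence the default True.
config = RBLNQwen2_5_VLForConditionalGenerationConfig(
    visual=visual,
    use_inputs_embeds=True,
    max_seq_len=32768,   # assumed kwarg inherited from the decoder-only parent config
    batch_size=1,
)
```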
optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py:

```diff
@@ -17,24 +17,21 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
 
 import torch
-from transformers import
-    AutoModelForVision2Seq,
-    PretrainedConfig,
-    PreTrainedModel,
-    Qwen2_5_VLForConditionalGeneration,
-)
+from transformers import AutoModelForVision2Seq, PretrainedConfig, PreTrainedModel, Qwen2_5_VLForConditionalGeneration
 from transformers.modeling_utils import no_init_weights
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VisionPatchEmbed,
     Qwen2_5_VisionRotaryEmbedding,
     Qwen2_5_VisionTransformerPretrainedModel,
+    Qwen2_5_VLModel,
     Qwen2_5_VLRotaryEmbedding,
 )
 
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from
+from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .configuration_qwen2_5_vl import (
     RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
     RBLNQwen2_5_VLForConditionalGenerationConfig,
@@ -45,12 +42,7 @@ from .qwen2_5_vl_architecture import Qwen2_5_VisionTransformerWrapper, Qwen2_5_V
 logger = get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
 
 class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
@@ -96,7 +88,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
 
     @classmethod
-    def
+    def _wrap_model_if_needed(
         cls, model: "PreTrainedModel", rbln_config: RBLNQwen2_5_VisionTransformerPretrainedModelConfig
     ):
         return Qwen2_5_VisionTransformerWrapper(model).eval()
@@ -381,6 +373,8 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
     ```
     """
 
+    _supports_non_fp32 = False
+
     auto_model_class = AutoModelForVision2Seq
     _rbln_submodules = [
         {"name": "visual"},
@@ -399,13 +393,11 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         return True
 
     @classmethod
-    def
-
-
-
-
-        )
-        return super().update_kwargs(kwargs)
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+        model.model.lm_head = model.lm_head
+        model.lm_head = None
+        del model.lm_head
+        return model
 
     @classmethod
     def get_input_info(
@@ -539,7 +531,8 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                 vision_tokens = input_id[0][vision_start_indices + 1]
                 image_nums = (vision_tokens == image_token_id).sum()
                 video_nums = (vision_tokens == video_token_id).sum()
-                position_ids, rope_deltas =
+                position_ids, rope_deltas = Qwen2_5_VLModel.get_rope_index(
+                    self,
                     input_id,
                     image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
                     video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,
```