optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- optimum/rbln/__init__.py +108 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +156 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +48 -21
- optimum/rbln/modeling_base.py +99 -22
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/runtime_utils.py +60 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/siglip/configuration_siglip.py
CHANGED

@@ -42,7 +42,7 @@ class RBLNSiglipVisionModelConfig(RBLNModelConfig):
             interpolate_pos_encoding (Optional[bool]): Whether to interpolate the position encoding.
             output_hidden_states: (Optional[bool]): Whether to return hidden states.
             output_attentions: (Optional[bool]): Whether to return attentions.
-
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
             ValueError: If batch_size is not a positive integer.

optimum/rbln/transformers/models/siglip/modeling_siglip.py
CHANGED

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
 
 import torch
 from transformers import SiglipVisionConfig, SiglipVisionModel
@@ -29,8 +29,6 @@ logger = get_logger(__name__)
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
 
-    from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
-
 
 class _SiglipVisionModel(torch.nn.Module):
     def __init__(
@@ -65,8 +63,12 @@ class RBLNSiglipVisionModel(RBLNModel):
     on RBLN devices, supporting image encoding for multimodal vision-language tasks.
     """
 
+    _tp_support = False
+
     @classmethod
-    def
+    def _wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig
+    ) -> torch.nn.Module:
         wrapper_cfg = {
             "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
             "output_hidden_states": rbln_config.output_hidden_states,
@@ -74,12 +76,6 @@ class RBLNSiglipVisionModel(RBLNModel):
         }
         return _SiglipVisionModel(model, **wrapper_cfg).eval()
 
-    @classmethod
-    def update_rbln_config_using_pipe(
-        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
-    ) -> "RBLNDiffusionMixinConfig":
-        return rbln_config
-
     @classmethod
     def _update_rbln_config(
         cls,
@@ -126,12 +122,21 @@ class RBLNSiglipVisionModel(RBLNModel):
         output_attentions: bool = None,
         output_hidden_states: bool = None,
         interpolate_pos_encoding: bool = False,
-        **kwargs:
+        **kwargs: Any,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
-
-
-
-
+        """
+        Forward pass for the RBLN-optimized SigLIP vision model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size), optional): The tensors corresponding to the input images. Pixel values can be obtained using ViTImageProcessor. See ViTImageProcessor.call() for details (processor_class uses ViTImageProcessor for processing images).
+            return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+            output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
+            output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+            interpolate_pos_encoding (bool, defaults to False): Whether to interpolate the pre-trained position encodings.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
+        """
 
         output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
         output_hidden_states = (
@@ -156,7 +161,7 @@ class RBLNSiglipVisionModel(RBLNModel):
                 f"Please compile again with the correct argument."
             )
 
-        output = super().forward(pixel_values, return_dict=return_dict)
+        output = super().forward(pixel_values, return_dict=return_dict, **kwargs)
         return output
 
     def _prepare_output(self, output, return_dict):
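
Note on the forward change above: output_attentions and output_hidden_states are fixed when the model is compiled, and the runtime check raises a ValueError if the call-time flags disagree. A minimal usage sketch follows; the checkpoint name, the export=True flag, and the rbln_config dictionary are assumptions based on the usual optimum-rbln from_pretrained pattern, not something this diff shows.

from PIL import Image
from transformers import AutoProcessor
from optimum.rbln import RBLNSiglipVisionModel  # assumed top-level export

processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")  # hypothetical checkpoint
model = RBLNSiglipVisionModel.from_pretrained(
    "google/siglip-base-patch16-224",
    export=True,                                 # compile from the Transformers checkpoint (assumed flag)
    rbln_config={"output_hidden_states": True},  # baked in at compile time
)
inputs = processor(images=Image.open("example.png"), return_tensors="pt")
outputs = model(**inputs, output_hidden_states=True)  # must match the compile-time setting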

optimum/rbln/transformers/models/swin/__init__.py
ADDED

@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_swin import RBLNSwinBackboneConfig
+from .modeling_swin import RBLNSwinBackbone

optimum/rbln/transformers/models/swin/configuration_swin.py
ADDED

@@ -0,0 +1,42 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Tuple, Union
+
+from ...configuration_generic import RBLNModelForImageClassificationConfig
+
+
+class RBLNSwinBackboneConfig(RBLNModelForImageClassificationConfig):
+    def __init__(
+        self,
+        image_size: Optional[Union[int, Tuple[int, int]]] = None,
+        batch_size: Optional[int] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.image_size = image_size
+        self.output_hidden_states = output_hidden_states
+        self.output_attentions = output_attentions

optimum/rbln/transformers/models/swin/modeling_swin.py
ADDED

@@ -0,0 +1,354 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import types
+from typing import TYPE_CHECKING, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from transformers import SwinConfig
+from transformers.models.swin.modeling_swin import BackboneOutput
+
+from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+from ....modeling import RBLNModel
+from ....utils.logging import get_logger
+from .configuration_swin import RBLNSwinBackboneConfig
+
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        PreTrainedModel,
+        SwinBackbone,
+        SwinEncoder,
+    )
+
+
+def window_partition(input_feature, window_size):
+    """
+    Partitions the given input into windows.
+    """
+    batch_size, height, width, num_channels = input_feature.shape
+    input_feature = input_feature.view(
+        batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
+    )
+    windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
+    return windows
+
+
+def get_attn_mask(self, height, width, dtype, device):
+    if self.shift_size > 0:
+        # calculate attention mask for SW-MSA
+        img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
+        height_slices = (
+            slice(0, -self.window_size),
+            slice(-self.window_size, -self.shift_size),
+            slice(-self.shift_size, None),
+        )
+        width_slices = (
+            slice(0, -self.window_size),
+            slice(-self.window_size, -self.shift_size),
+            slice(-self.shift_size, None),
+        )
+        count = torch.zeros(1)
+        for height_slice in height_slices:
+            for width_slice in width_slices:
+                img_mask[:, height_slice, width_slice, :] = count
+                count += 1
+
+        mask_windows = window_partition(img_mask, self.window_size)
+        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+    else:
+        attn_mask = None
+    return attn_mask
+
+
+class _SwinEncoder(torch.nn.Module):
+    def __init__(self, model: "SwinEncoder"):
+        super().__init__()
+        self.layers = model.layers
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_dimensions: Tuple[int, int],
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        output_hidden_states_before_downsampling: Optional[bool] = False,
+        always_partition: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_reshaped_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        if output_hidden_states:
+            batch_size, _, hidden_size = hidden_states.shape
+            # rearrange b (h w) c -> b c h w
+            reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+            all_hidden_states += (hidden_states,)
+            all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+        for i, layer_module in enumerate(self.layers):
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+
+            layer_outputs = layer_module(
+                hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+            )
+
+            hidden_states = layer_outputs[0]
+            hidden_states_before_downsampling = layer_outputs[1]
+            output_dimensions = layer_outputs[2]
+
+            input_dimensions = (output_dimensions[-2], output_dimensions[-1])
+
+            if output_hidden_states and output_hidden_states_before_downsampling:
+                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
+                # rearrange b (h w) c -> b c h w
+                # here we use the original (not downsampled) height and width
+                reshaped_hidden_state = hidden_states_before_downsampling.view(
+                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
+                )
+                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states_before_downsampling,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+            elif output_hidden_states and not output_hidden_states_before_downsampling:
+                batch_size, _, hidden_size = hidden_states.shape
+                # rearrange b (h w) c -> b c h w
+                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+            if output_attentions:
+                all_self_attentions += layer_outputs[3:]
+
+        return tuple(
+            v
+            for v in [hidden_states, all_hidden_states, all_self_attentions, all_reshaped_hidden_states]
+            if v is not None
+        )
+
+
+class _SwinBackbone(torch.nn.Module):
+    def __init__(self, model: "SwinBackbone", output_hidden_states: bool, output_attentions: bool):
+        super().__init__()
+        self.model = model
+        self.embeddings = model.embeddings
+        self.encoder = model.encoder
+        self.stage_names = model.stage_names
+        self.out_features = model.out_features
+        self.hidden_states_norms = model.hidden_states_norms
+        self.output_hidden_states = output_hidden_states
+        self.output_attentions = output_attentions
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+    ):
+        embedding_output, input_dimensions = self.embeddings(pixel_values)
+        outputs = _SwinEncoder(self.encoder)(
+            embedding_output,
+            input_dimensions,
+            head_mask=None,
+            output_attentions=self.output_attentions,
+            output_hidden_states=True,
+            output_hidden_states_before_downsampling=True,
+            always_partition=True,
+            return_dict=False,
+        )
+
+        hidden_states = outputs[-1]
+
+        feature_maps = ()
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                batch_size, num_channels, height, width = hidden_state.shape
+                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
+                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps += (hidden_state,)
+
+        output = (feature_maps,)
+
+        if self.output_hidden_states:
+            output += (outputs[1],)
+
+        if self.output_attentions:
+            output += (outputs[2],)
+
+        return output
+
+
+class RBLNSwinBackbone(RBLNModel):
+    @classmethod
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSwinBackboneConfig) -> torch.nn.Module:
+        for layer in model.encoder.layers:
+            for block in layer.blocks:
+                block.get_attn_mask = types.MethodType(get_attn_mask, block)
+
+        wrapper_cfg = {
+            "output_hidden_states": rbln_config.output_hidden_states,
+            "output_attentions": rbln_config.output_attentions,
+        }
+        return _SwinBackbone(model, **wrapper_cfg).eval()
+
+    @classmethod
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
+        for processor in preprocessors:
+            if rbln_config.image_size is None and hasattr(processor, "image_processor"):
+                if "height" in processor.image_processor.size and "width" in processor.image_processor.size:
+                    rbln_config.image_size = (
+                        processor.image_processor.size["height"],
+                        processor.image_processor.size["width"],
+                    )
+                elif (
+                    "longest_edge" in processor.image_processor.size
+                    and "shortest_edge" in processor.image_processor.size
+                ):
+                    rbln_config.image_size = processor.image_processor.size["longest_edge"]
+                elif "shortest_edge" in processor.image_processor.size:
+                    rbln_config.image_size = processor.image_processor.size["shortest_edge"]
+                break
+
+        return rbln_config
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: "SwinConfig" = None,
+        rbln_config: Optional[RBLNSwinBackboneConfig] = None,
+    ) -> RBLNSwinBackboneConfig:
+        if rbln_config.image_size is None:
+            for processor in preprocessors:
+                if hasattr(processor, "size"):
+                    if all(required_key in processor.size.keys() for required_key in ["height", "width"]):
+                        rbln_config.image_size = (processor.size["height"], processor.size["width"])
+                    break
+
+        input_info = [
+            (
+                "pixel_values",
+                [
+                    rbln_config.batch_size,
+                    3,
+                    rbln_config.image_height,
+                    rbln_config.image_width,
+                ],
+                "float32",
+            ),
+        ]
+
+        rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
+        return rbln_config
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        return_dict: bool = True,
+        output_attentions: bool = None,
+        output_hidden_states: bool = None,
+        **kwargs,
+    ) -> Union[Tuple, BackboneOutput]:
+        """
+        Forward pass for the RBLN-optimized Swin backbone model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size), optional): The tensors corresponding to the input images. Pixel values can be obtained using ViTImageProcessor. See ViTImageProcessor.call() for details (processor_class uses ViTImageProcessor for processing images).
+            return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+            output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
+            output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BackboneOutput object.
+        """
+
+        if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
+            logger.warning(
+                f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
+            )
+
+        output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+
+        if output_attentions != self.rbln_config.output_attentions:
+            raise ValueError(
+                f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
+                f"Please compile again with the correct argument."
+            )
+
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
+        _, _, original_h, original_w = pixel_values.shape
+        if original_h > self.rbln_config.image_height or original_w > self.rbln_config.image_width:
+            raise ValueError(
+                f"Input image size ({original_h}x{original_w}) exceeds the configured maximum size"
+                f" ({self.rbln_config.image_height}x{self.rbln_config.image_width})."
+            )
+
+        pad_h = self.rbln_config.image_height - original_h
+        pad_w = self.rbln_config.image_width - original_w
+        padded_pixel_values = F.pad(pixel_values, (0, pad_w, 0, pad_h))
+
+        output = self.model[0](padded_pixel_values)
+
+        feature_maps = ()
+        for i in range(len(self.config.out_features)):
+            feature_maps += (output.pop(0),)
+
+        if self.rbln_config.output_hidden_states:
+            hidden_states = ()
+            for i in range(len(self.config.stage_names)):
+                hidden_states += (output.pop(0),)
+        else:
+            hidden_states = None
+
+        if self.rbln_config.output_attentions:
+            attentions = ()
+            for i in range(len(self.config.depths)):
+                attentions += (output.pop(0),)
+        else:
+            attentions = None
+
+        if not return_dict:
+            return tuple(item for item in (feature_maps, hidden_states, attentions) if item is not None)
+        else:
+            return BackboneOutput(
+                feature_maps=feature_maps,
+                hidden_states=hidden_states,
+                attentions=attentions,
+            )
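
The forward above pads smaller inputs up to the compiled image size and rejects anything larger, so the static RBLN graph always sees one shape. A rough compile-and-run sketch follows; the checkpoint name, the export=True flag, and the rbln_config dictionary form are assumptions following the common optimum-rbln pattern, while RBLNSwinBackbone and its image_size/batch_size fields come from this diff.

import torch
from optimum.rbln import RBLNSwinBackbone  # assumed re-export at the package root

backbone = RBLNSwinBackbone.from_pretrained(
    "microsoft/swin-tiny-patch4-window7-224",                 # hypothetical checkpoint
    export=True,                                              # compile for the RBLN NPU (assumed flag)
    rbln_config={"image_size": (224, 224), "batch_size": 1},
)
pixel_values = torch.rand(1, 3, 192, 200)   # smaller than 224x224: padded internally before inference
outputs = backbone(pixel_values)            # BackboneOutput with one feature map per configured out_feature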

optimum/rbln/transformers/models/t5/modeling_t5.py
CHANGED

@@ -68,7 +68,7 @@ class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
     output_class = BaseModelOutputWithPastAndCrossAttentions
 
     @classmethod
-    def
+    def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
         return T5EncoderWrapper(model)
 
     @classmethod
@@ -113,7 +113,7 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
     support_causal_attn = False
 
     @classmethod
-    def
+    def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5ForConditionalGenerationConfig):
         return T5Wrapper(
             model, enc_max_seq_len=rbln_config.enc_max_seq_len, dec_max_seq_len=rbln_config.dec_max_seq_len
         )

optimum/rbln/transformers/models/t5/t5_architecture.py
CHANGED

@@ -126,7 +126,14 @@ class T5Decoder(Seq2SeqDecoder):
         b_size = attention_mask.shape[0]
         batch_decoder_position_bias = []
         for i in range(b_size):
-
+            if torch.compiler.is_exporting():
+                cache_pos = cache_position[i][0].item()
+                torch._check_is_size(cache_pos)
+                torch._check(cache_pos >= 0)
+                torch._check(cache_pos < self._dec_position_bias.shape[2])
+            else:
+                cache_pos = cache_position[i][0]
+            batch_position_bias = torch.select(self._dec_position_bias, dim=2, index=cache_pos).unsqueeze(2)
             batch_decoder_position_bias.append(batch_position_bias)
         position_bias = torch.cat(batch_decoder_position_bias, dim=0)
 
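
The T5Decoder change above and the time-series decoder change below rely on the same export-time trick: a data-dependent cache index is converted to a Python int and bounded with torch._check so indexing stays traceable. A standalone sketch of the pattern in plain PyTorch (not optimum-rbln code; torch.compiler.is_exporting() requires a recent PyTorch release):

import torch

def gather_position(table: torch.Tensor, position: torch.Tensor) -> torch.Tensor:
    # `position` is a 0-dim integer tensor whose value is only known at run time.
    if torch.compiler.is_exporting():
        idx = position.item()               # becomes an unbacked symbolic int under export
        torch._check_is_size(idx)           # tell the exporter it is size-like (non-negative)
        torch._check(idx < table.shape[0])  # and in range, so the index needs no data-dependent guard
    else:
        idx = position                      # eager mode: tensor indexing works directly
    return table[idx]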

optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import Any,
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
 
@@ -17,7 +17,7 @@ class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
         enc_max_seq_len: Optional[int] = None,
         dec_max_seq_len: Optional[int] = None,
         num_parallel_samples: Optional[int] = None,
-        **kwargs:
+        **kwargs: Any,
     ):
         """
         Args:
@@ -25,7 +25,7 @@ class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
             enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
             dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
             num_parallel_samples (Optional[int]): Number of samples to generate in parallel during prediction.
-
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
             ValueError: If batch_size is not a positive integer.

optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py
CHANGED

@@ -23,24 +23,20 @@
 
 import inspect
 import logging
-from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, List, Optional,
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
 
 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import
-
-    TimeSeriesTransformerForPrediction,
-    TimeSeriesTransformerModel,
-)
-from transformers.modeling_outputs import ModelOutput, SampleTSPredictionOutput, Seq2SeqTSModelOutput
+from transformers import PretrainedConfig, TimeSeriesTransformerForPrediction, TimeSeriesTransformerModel
+from transformers.modeling_outputs import SampleTSPredictionOutput, Seq2SeqTSModelOutput
 from transformers.modeling_utils import no_init_weights
 
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from ...modeling_outputs import RBLNSeq2SeqTSDecoderOutput
 from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
 from .time_series_transformers_architecture import TimeSeriesTransformersWrapper
 
@@ -113,12 +109,6 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         )
 
 
-@dataclass
-class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
-    last_hidden_states: torch.FloatTensor = None
-    params: Tuple[torch.FloatTensor] = None
-
-
 class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
     """
     The Time Series Transformer Model with a distribution head on top for time-series forecasting. e.g., for datasets like M4, NN5, or other time series forecasting benchmarks.
@@ -163,7 +153,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
         return redirect(val)
 
     @classmethod
-    def
+    def _wrap_model_if_needed(
         self, model: "PreTrainedModel", rbln_config: RBLNTimeSeriesTransformerForPredictionConfig
     ):
         return TimeSeriesTransformersWrapper(model, rbln_config.num_parallel_samples)
@@ -171,7 +161,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model, rbln_config: RBLNTimeSeriesTransformerForPredictionConfig):
-        wrapped_model = cls.
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
 
         enc_compile_config = rbln_config.compile_cfgs[0]
         dec_compile_config = rbln_config.compile_cfgs[1]
@@ -363,6 +353,20 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
         static_real_features: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> SampleTSPredictionOutput:
+        """
+        Generate pass for the RBLN-optimized Time Series Transformer model for time series forecasting.
+
+        Args:
+            past_values (torch.FloatTensor of shape (batch_size, sequence_length) or (batch_size, sequence_length, input_size)): Past values of the time series, that serve as context in order to predict the future.
+            past_time_features (torch.FloatTensor of shape (batch_size, sequence_length, num_features)): Required time features, which the model internally will add to past_values.
+            future_time_features (torch.FloatTensor of shape (batch_size, prediction_length, num_features)): Required time features for the prediction window, which the model internally will add to future_values.
+            past_observed_mask (torch.BoolTensor of shape (batch_size, sequence_length) or (batch_size, sequence_length, input_size), optional): Boolean mask to indicate which past_values were observed and which were missing.
+            static_categorical_features (torch.LongTensor of shape (batch_size, number of static categorical features), optional): Optional static categorical features for which the model will learn an embedding, which it will add to the values of the time series.
+            static_real_features (torch.FloatTensor of shape (batch_size, number of static real features), optional): Optional static real features which the model will add to the values of the time series.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SampleTSPredictionOutput object.
+        """
         self.validate_batch_size(**{k: v for k, v in locals().items() if isinstance(v, torch.Tensor)})
 
         outputs = self.encoder(

optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py
CHANGED

@@ -162,7 +162,13 @@ class TimeSeriesTransformersDecoder(nn.Module):
         attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)
 
         hidden_states = self.value_embedding(inputs_embeds)
-
+        embed_idx = cache_position + self.config.context_length
+        if torch.compiler.is_exporting():
+            embed_idx = embed_idx.item()
+            torch._check_is_size(embed_idx)
+            torch._check(embed_idx >= 0)
+            torch._check(embed_idx < len(self.embed_positions.weight))
+        embed_pos = self.embed_positions.weight[embed_idx]
         hidden_states = self.layernorm_embedding(hidden_states + embed_pos)
 
         # iterate decoder_layer