optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +96 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +153 -42
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
- optimum/rbln/diffusers/pipelines/__init__.py +11 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +71 -19
- optimum/rbln/modeling_base.py +99 -21
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +9 -7
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/modeling_generic.py +51 -9
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/depreacate_utils.py +16 -0
- optimum/rbln/utils/runtime_utils.py +28 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
- optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/swin/modeling_swin.py
ADDED

@@ -0,0 +1,341 @@

# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types
from typing import TYPE_CHECKING, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import SwinConfig
from transformers.models.swin.modeling_swin import BackboneOutput

from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
from ....modeling import RBLNModel
from ....utils.logging import get_logger
from .configuration_swin import RBLNSwinBackboneConfig


logger = get_logger(__name__)

if TYPE_CHECKING:
    from transformers import (
        AutoFeatureExtractor,
        AutoProcessor,
        AutoTokenizer,
        PreTrainedModel,
        SwinBackbone,
        SwinEncoder,
    )


def window_partition(input_feature, window_size):
    """
    Partitions the given input into windows.
    """
    batch_size, height, width, num_channels = input_feature.shape
    input_feature = input_feature.view(
        batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
    )
    windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
    return windows


def get_attn_mask(self, height, width, dtype, device):
    if self.shift_size > 0:
        # calculate attention mask for SW-MSA
        img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
        height_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        width_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        count = torch.zeros(1)
        for height_slice in height_slices:
            for width_slice in width_slices:
                img_mask[:, height_slice, width_slice, :] = count
                count += 1

        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    else:
        attn_mask = None
    return attn_mask


class _SwinEncoder(torch.nn.Module):
    def __init__(self, model: "SwinEncoder"):
        super().__init__()
        self.layers = model.layers

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        always_partition: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            batch_size, _, hidden_size = hidden_states.shape
            # rearrange b (h w) c -> b c h w
            reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for i, layer_module in enumerate(self.layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
            )

            hidden_states = layer_outputs[0]
            hidden_states_before_downsampling = layer_outputs[1]
            output_dimensions = layer_outputs[2]

            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states and output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
                # rearrange b (h w) c -> b c h w
                # here we use the original (not downsampled) height and width
                reshaped_hidden_state = hidden_states_before_downsampling.view(
                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
                )
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                all_self_attentions += layer_outputs[3:]

        return tuple(
            v
            for v in [hidden_states, all_hidden_states, all_self_attentions, all_reshaped_hidden_states]
            if v is not None
        )


class _SwinBackbone(torch.nn.Module):
    def __init__(self, model: "SwinBackbone", output_hidden_states: bool, output_attentions: bool):
        super().__init__()
        self.model = model
        self.embeddings = model.embeddings
        self.encoder = model.encoder
        self.stage_names = model.stage_names
        self.out_features = model.out_features
        self.hidden_states_norms = model.hidden_states_norms
        self.output_hidden_states = output_hidden_states
        self.output_attentions = output_attentions

    def forward(
        self,
        pixel_values: torch.Tensor,
    ):
        embedding_output, input_dimensions = self.embeddings(pixel_values)
        outputs = _SwinEncoder(self.encoder)(
            embedding_output,
            input_dimensions,
            head_mask=None,
            output_attentions=self.output_attentions,
            output_hidden_states=True,
            output_hidden_states_before_downsampling=True,
            always_partition=True,
            return_dict=False,
        )

        hidden_states = outputs[-1]

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                batch_size, num_channels, height, width = hidden_state.shape
                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
                hidden_state = self.hidden_states_norms[stage](hidden_state)
                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        output = (feature_maps,)

        if self.output_hidden_states:
            output += (outputs[1],)

        if self.output_attentions:
            output += (outputs[2],)

        return output


class RBLNSwinBackbone(RBLNModel):
    @classmethod
    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSwinBackboneConfig) -> torch.nn.Module:
        for layer in model.encoder.layers:
            for block in layer.blocks:
                block.get_attn_mask = types.MethodType(get_attn_mask, block)

        wrapper_cfg = {
            "output_hidden_states": rbln_config.output_hidden_states,
            "output_attentions": rbln_config.output_attentions,
        }
        return _SwinBackbone(model, **wrapper_cfg).eval()

    @classmethod
    def _update_submodule_config(
        cls,
        model: "PreTrainedModel",
        rbln_config: RBLNModelConfig,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
    ):
        for processor in preprocessors:
            if rbln_config.image_size is None and hasattr(processor, "image_processor"):
                if "height" in processor.image_processor.size and "width" in processor.image_processor.size:
                    rbln_config.image_size = (
                        processor.image_processor.size["height"],
                        processor.image_processor.size["width"],
                    )
                elif (
                    "longest_edge" in processor.image_processor.size
                    and "shortest_edge" in processor.image_processor.size
                ):
                    rbln_config.image_size = processor.image_processor.size["longest_edge"]
                elif "shortest_edge" in processor.image_processor.size:
                    rbln_config.image_size = processor.image_processor.size["shortest_edge"]
                break

        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: Optional["PreTrainedModel"] = None,
        model_config: "SwinConfig" = None,
        rbln_config: Optional[RBLNSwinBackboneConfig] = None,
    ) -> RBLNSwinBackboneConfig:
        if rbln_config.image_size is None:
            for processor in preprocessors:
                if hasattr(processor, "size"):
                    if all(required_key in processor.size.keys() for required_key in ["height", "width"]):
                        rbln_config.image_size = (processor.size["height"], processor.size["width"])
                        break

        input_info = [
            (
                "pixel_values",
                [
                    rbln_config.batch_size,
                    3,
                    rbln_config.image_height,
                    rbln_config.image_width,
                ],
                "float32",
            ),
        ]

        rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
        return rbln_config

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        **kwargs,
    ) -> Union[Tuple, BackboneOutput]:
        if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
            logger.warning(
                f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
            )

        output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
        )

        if output_attentions != self.rbln_config.output_attentions:
            raise ValueError(
                f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
                f"Please compile again with the correct argument."
            )

        if output_hidden_states != self.rbln_config.output_hidden_states:
            raise ValueError(
                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
                f"Please compile again with the correct argument."
            )

        _, _, original_h, original_w = pixel_values.shape
        if original_h > self.rbln_config.image_height or original_w > self.rbln_config.image_width:
            raise ValueError(
                f"Input image size ({original_h}x{original_w}) exceeds the configured maximum size"
                f" ({self.rbln_config.image_height}x{self.rbln_config.image_width})."
            )

        pad_h = self.rbln_config.image_height - original_h
        pad_w = self.rbln_config.image_width - original_w
        padded_pixel_values = F.pad(pixel_values, (0, pad_w, 0, pad_h))

        output = self.model[0](padded_pixel_values)

        feature_maps = ()
        for i in range(len(self.config.out_features)):
            feature_maps += (output.pop(0),)

        if self.rbln_config.output_hidden_states:
            hidden_states = ()
            for i in range(len(self.config.stage_names)):
                hidden_states += (output.pop(0),)
        else:
            hidden_states = None

        if self.rbln_config.output_attentions:
            attentions = ()
            for i in range(len(self.config.depths)):
                attentions += (output.pop(0),)
        else:
            attentions = None

        if not return_dict:
            return tuple(item for item in (feature_maps, hidden_states, attentions) if item is not None)
        else:
            return BackboneOutput(
                feature_maps=feature_maps,
                hidden_states=hidden_states,
                attentions=attentions,
            )
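For orientation, a minimal usage sketch of the new Swin backbone class follows. It assumes RBLNSwinBackbone is re-exported from the package root like other RBLN model classes, that the standard optimum-rbln from_pretrained(..., export=True) compile flow applies, and that the checkpoint id and image size are placeholders; none of these is confirmed by this diff.

import torch
from optimum.rbln import RBLNSwinBackbone  # assumed top-level re-export

# Hypothetical checkpoint id; any Swin backbone checkpoint would do.
model = RBLNSwinBackbone.from_pretrained(
    "microsoft/swin-tiny-patch4-window7-224",
    export=True,  # compile the HF checkpoint for the RBLN device
)

# Inputs smaller than the compiled image size are padded inside forward();
# larger inputs raise a ValueError.
pixel_values = torch.randn(1, 3, 224, 224)
outputs = model(pixel_values, return_dict=True)
print([fm.shape for fm in outputs.feature_maps])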
optimum/rbln/transformers/models/t5/t5_architecture.py
CHANGED

@@ -126,7 +126,14 @@ class T5Decoder(Seq2SeqDecoder):
         b_size = attention_mask.shape[0]
         batch_decoder_position_bias = []
         for i in range(b_size):
-
+            if torch.compiler.is_exporting():
+                cache_pos = cache_position[i][0].item()
+                torch._check_is_size(cache_pos)
+                torch._check(cache_pos >= 0)
+                torch._check(cache_pos < self._dec_position_bias.shape[2])
+            else:
+                cache_pos = cache_position[i][0]
+            batch_position_bias = torch.select(self._dec_position_bias, dim=2, index=cache_pos).unsqueeze(2)
             batch_decoder_position_bias.append(batch_position_bias)
         position_bias = torch.cat(batch_decoder_position_bias, dim=0)
 
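This hunk, and the time-series decoder change further down, use the same export-guard idiom: under torch.export a data-dependent index is materialized with .item() and bounded with torch._check hints so the exporter can prove the select is in range, while eager execution keeps the tensor index. A self-contained sketch of the idiom (illustrative, not taken from the package):

import torch


def gather_bias(position_bias: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
    # position_bias: [batch, heads, max_len, max_len]; cache_position: 0-dim int tensor.
    if torch.compiler.is_exporting():
        # Under torch.export the index becomes an unbacked symbolic int, so we
        # register size and bound hints that let the exporter validate the indexing.
        idx = cache_position.item()
        torch._check_is_size(idx)
        torch._check(idx >= 0)
        torch._check(idx < position_bias.shape[2])
    else:
        idx = cache_position  # eager mode can index with the tensor directly
    return position_bias[:, :, idx].unsqueeze(2)


# Eager example: selects the bias column for decoding step 3.
print(gather_bias(torch.randn(2, 8, 16, 16), torch.tensor(3)).shape)  # torch.Size([2, 8, 1, 16])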
optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import Any,
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
 

@@ -17,7 +17,7 @@ class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
         enc_max_seq_len: Optional[int] = None,
         dec_max_seq_len: Optional[int] = None,
         num_parallel_samples: Optional[int] = None,
-        **kwargs:
+        **kwargs: Any,
     ):
         """
         Args:

@@ -25,7 +25,7 @@ class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
             enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
             dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
             num_parallel_samples (Optional[int]): Number of samples to generate in parallel during prediction.
-
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
             ValueError: If batch_size is not a positive integer.
optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py
CHANGED

@@ -23,24 +23,20 @@
 
 import inspect
 import logging
-from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, List, Optional,
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
 
 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import
-
-    TimeSeriesTransformerForPrediction,
-    TimeSeriesTransformerModel,
-)
-from transformers.modeling_outputs import ModelOutput, SampleTSPredictionOutput, Seq2SeqTSModelOutput
+from transformers import PretrainedConfig, TimeSeriesTransformerForPrediction, TimeSeriesTransformerModel
+from transformers.modeling_outputs import SampleTSPredictionOutput, Seq2SeqTSModelOutput
 from transformers.modeling_utils import no_init_weights
 
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from ...modeling_outputs import RBLNSeq2SeqTSDecoderOutput
 from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
 from .time_series_transformers_architecture import TimeSeriesTransformersWrapper
 

@@ -113,12 +109,6 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         )
 
 
-@dataclass
-class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
-    last_hidden_states: torch.FloatTensor = None
-    params: Tuple[torch.FloatTensor] = None
-
-
 class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
     """
     The Time Series Transformer Model with a distribution head on top for time-series forecasting. e.g., for datasets like M4, NN5, or other time series forecasting benchmarks.
optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py
CHANGED

@@ -162,7 +162,13 @@ class TimeSeriesTransformersDecoder(nn.Module):
         attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)
 
         hidden_states = self.value_embedding(inputs_embeds)
-
+        embed_idx = cache_position + self.config.context_length
+        if torch.compiler.is_exporting():
+            embed_idx = embed_idx.item()
+            torch._check_is_size(embed_idx)
+            torch._check(embed_idx >= 0)
+            torch._check(embed_idx < len(self.embed_positions.weight))
+        embed_pos = self.embed_positions.weight[embed_idx]
         hidden_states = self.layernorm_embedding(hidden_states + embed_pos)
 
         # iterate decoder_layer
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py
CHANGED

@@ -38,6 +38,7 @@ class RBLNWav2Vec2ForCTC(RBLNModelForMaskedLM):
     library implements for all its model.
 
     It implements the methods to convert a pre-trained Wav2Vec2 model into a RBLN Wav2Vec2 model by:
+
     - transferring the checkpoint weights of the original into an optimized RBLN graph,
     - compiling the resulting graph using the RBLN compiler.
     """
optimum/rbln/transformers/models/whisper/configuration_whisper.py
CHANGED

@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any
-
-import rebel
+from typing import Any
 
 from ....configuration_utils import RBLNModelConfig
 from ....utils.logging import get_logger

@@ -38,17 +36,22 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
         use_attention_mask: bool = None,
         enc_max_seq_len: int = None,
         dec_max_seq_len: int = None,
-
+        kvcache_num_blocks: int = None,
+        kvcache_block_size: int = None,
+        **kwargs: Any,
     ):
         """
         Args:
             batch_size (int, optional): The batch size for inference. Defaults to 1.
             token_timestamps (bool, optional): Whether to output token timestamps during generation. Defaults to False.
             use_attention_mask (bool, optional): Whether to use attention masks during inference. This is automatically
-                set to True for RBLN-CA02 devices.
             enc_max_seq_len (int, optional): Maximum sequence length for the encoder.
             dec_max_seq_len (int, optional): Maximum sequence length for the decoder.
-
+            kvcache_num_blocks (int, optional): The total number of blocks to allocate for the
+                PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
+            kvcache_block_size (int, optional): Sets the size (in number of tokens) of each block
+                in the PagedAttention KV cache for the SelfAttention. Defaults to dec_max_seq_len.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
             ValueError: If batch_size is not a positive integer.

@@ -64,10 +67,6 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
         self.dec_max_seq_len = dec_max_seq_len
 
         self.use_attention_mask = use_attention_mask
-
-
-
-            logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
-            self.use_attention_mask = True
-        else:
-            self.use_attention_mask = self.use_attention_mask or False
+        self.use_attention_mask = self.use_attention_mask or False
+        self.kvcache_num_blocks = kvcache_num_blocks
+        self.kvcache_block_size = kvcache_block_size
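A minimal sketch of setting the two new knobs explicitly; the constructor signature comes from the hunk above, but the import path is assumed (top-level re-export) and the values are placeholders.

from optimum.rbln import RBLNWhisperForConditionalGenerationConfig  # assumed import path

# In this release kvcache_num_blocks must equal batch_size and kvcache_block_size
# must equal dec_max_seq_len (see _update_paged_attention_config below), so leaving
# both unset and taking the documented defaults is equivalent.
rbln_config = RBLNWhisperForConditionalGenerationConfig(
    batch_size=2,
    dec_max_seq_len=448,
    kvcache_num_blocks=2,
    kvcache_block_size=448,
)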
optimum/rbln/transformers/models/whisper/generation_whisper.py
CHANGED

@@ -39,14 +39,31 @@ from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
 
 
 class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
-
-
-
-
-
+    def generate(self, *args, generation_config=None, **kwargs):
+        num_beams = kwargs.get(
+            "num_beams",
+            generation_config.num_beams
+            if hasattr(generation_config, "num_beams") and generation_config.num_beams is not None
+            else 1,
+        )
+        if num_beams > 1:
+            raise ValueError(
+                f"Beam search is not supported in RBLNWhisperGenerationMixin. "
+                f"Received num_beams={num_beams}, but only num_beams=1 is allowed. "
+                f"Please set num_beams=1 for greedy search or adjust your configuration."
+            )
+
+        return super().generate(*args, **kwargs)
 
     def _postprocess_outputs(
-        self,
+        self,
+        seek_outputs,
+        decoder_input_ids,
+        return_token_timestamps,
+        generation_config,
+        is_shortform,
+        seek,
+        batch_idx_map,
     ):
         # remove all previously passed decoder input ids
         # should happen only if it is the first generated segment

@@ -64,6 +81,11 @@ class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
 
         if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
             num_frames = getattr(generation_config, "num_frames", None)
+
+            if num_frames is not None:
+                num_frames = num_frames - seek
+                num_frames = num_frames[batch_idx_map]
+
             if version.parse(transformers.__version__) >= version.parse("4.46.0"):
                 seek_outputs["token_timestamps"] = self._extract_token_timestamps(
                     seek_outputs,
optimum/rbln/transformers/models/whisper/modeling_whisper.py
CHANGED

@@ -46,7 +46,7 @@ if TYPE_CHECKING:
 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
-    def forward(self, *args: List[torch.Tensor], **kwargs:
+    def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
         output = super().forward(*args, **kwargs)
         return BaseModelOutput(last_hidden_state=output)
 

@@ -73,6 +73,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         decoder_input_ids: torch.Tensor = None,
         decoder_attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
+        block_tables: torch.Tensor = None,
     ):
         inputs_bsz = decoder_input_ids.shape[0]
         padded_bsz = self.batch_size - inputs_bsz

@@ -89,11 +90,14 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
             )
             decoder_attention_mask[b_idx, : decoding_step + 1] = 1
 
+        if block_tables is None:
+            block_tables = self.default_block_tables
+
         outputs = super().forward(
             decoder_input_ids,
             decoder_attention_mask if self.use_attention_mask else None,
             cache_position,
-            block_tables=
+            block_tables=block_tables,
         )
 
         if isinstance(outputs, torch.Tensor):

@@ -108,6 +112,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
 
     This model inherits from [`RBLNModel`]. It implements the methods to convert and run
     pre-trained transformers based Whisper model on RBLN devices by:
+
     - transferring the checkpoint weights of the original into an optimized RBLN graph,
     - compiling the resulting graph using the RBLN compiler.
 

@@ -145,7 +150,8 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
     """
 
     auto_model_class = AutoModelForSpeechSeq2Seq
-    main_input_name = "
+    main_input_name = "input_features"
+    _is_stateful = False
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)

@@ -249,6 +255,23 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
 
         return {"encoder": compiled_encoder, "decoder": compiled_decoder}
 
+    @classmethod
+    def _update_paged_attention_config(
+        cls, model_config: "PretrainedConfig", rbln_config: RBLNWhisperForConditionalGenerationConfig
+    ):
+        rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
+        rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
+
+        if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
+            raise NotImplementedError(
+                f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
+            )
+
+        if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
+            raise NotImplementedError(
+                f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
+            )
+
     @classmethod
     def _update_rbln_config(
         cls,

@@ -266,6 +289,8 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         if rbln_config.dec_max_seq_len is None:
             rbln_config.dec_max_seq_len = model_config.max_length
 
+        cls._update_paged_attention_config(model_config, rbln_config)
+
         enc_input_info = [
             ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
             ("block_tables", [1], "int16"),
optimum/rbln/transformers/models/xlm_roberta/__init__.py
CHANGED

@@ -12,14 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .configuration_xlm_roberta import
-
-    RBLNXLMRobertaModelConfig,
-)
-from .modeling_xlm_roberta import (
-    RBLNXLMRobertaForSequenceClassification,
-    RBLNXLMRobertaModel,
-)
+from .configuration_xlm_roberta import RBLNXLMRobertaForSequenceClassificationConfig, RBLNXLMRobertaModelConfig
+from .modeling_xlm_roberta import RBLNXLMRobertaForSequenceClassification, RBLNXLMRobertaModel
 
 
 __all__ = [