optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +173 -35
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +816 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +111 -137
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +56 -71
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
- optimum/rbln/modeling.py +66 -40
- optimum/rbln/modeling_base.py +111 -86
- optimum/rbln/ops/__init__.py +4 -7
- optimum/rbln/ops/attn.py +271 -205
- optimum/rbln/ops/flash_attn.py +161 -67
- optimum/rbln/ops/kv_cache_update.py +4 -40
- optimum/rbln/ops/linear.py +25 -0
- optimum/rbln/transformers/__init__.py +97 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +120 -32
- optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
- optimum/rbln/transformers/models/bart/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
- optimum/rbln/transformers/models/t5/__init__.py +2 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
- optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
- optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +2 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/hub.py +2 -2
- optimum/rbln/utils/import_utils.py +23 -6
- optimum/rbln/utils/model_utils.py +4 -4
- optimum/rbln/utils/runtime_utils.py +33 -2
- optimum/rbln/utils/submodule.py +36 -44
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
- optimum_rbln-0.7.4.dist-info/RECORD +169 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py
ADDED
@@ -0,0 +1,331 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from typing import Optional, Tuple, Union
|
25
|
+
|
26
|
+
import torch
|
27
|
+
from torch import nn
|
28
|
+
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
29
|
+
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
30
|
+
from transformers.utils import logging
|
31
|
+
|
32
|
+
|
33
|
+
logger = logging.get_logger(__name__)
|
34
|
+
|
35
|
+
|
36
|
+
class TimeSeriesTransformersWrapper:
    """Pair of encoder/decoder wrapper modules derived from a single time-series transformer.

    Attributes:
        encoder: ``TimeSeriesTransformersEncoderWrapper`` built from ``model``.
        decoder: ``TimeSeriesTransformersDecoderWrapper`` built from ``model`` and
            ``num_parallel_samples``.
    """

    def __init__(self, model, num_parallel_samples):
        encoder_module = TimeSeriesTransformersEncoderWrapper(model)
        decoder_module = TimeSeriesTransformersDecoderWrapper(model, num_parallel_samples)
        self.encoder, self.decoder = encoder_module, decoder_module
|
40
|
+
|
41
|
+
|
42
|
+
class TimeSeriesTransformersEncoderWrapper(torch.nn.Module):
    """Encoder module that also pre-computes the decoder's cross-attention key/value cache.

    ``forward`` runs the original encoder and projects its last hidden states through
    every decoder layer's cross-attention ``k_proj``/``v_proj``. It then writes the stacked
    result into the ``cross_key_values`` cache via the ``rbln_cache_update`` custom op.
    """

    def __init__(self, model):
        super().__init__()
        self.config = model.config
        self.encoder = model.get_encoder()
        # The cached K/V are consumed by the decoder's cross-attention, so heads are split
        # with the *decoder* head count.
        self.num_heads = self.config.decoder_attention_heads
        self.d_kv = self.config.d_model // self.num_heads
        self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().layers)

    def _extract_cross_kv_projects(self, decoder_layers: nn.Module):
        """Collect the cross-attention key and value projections of every decoder layer.

        Returns:
            A ``(k_projs, v_projs)`` pair of ``nn.ModuleList`` objects, one entry per layer.
            The projection modules are shared with the original model, not copied.
        """
        return (
            nn.ModuleList(layer.encoder_attn.k_proj for layer in decoder_layers),
            nn.ModuleList(layer.encoder_attn.v_proj for layer in decoder_layers),
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        cross_key_values: torch.Tensor,  # 2 * n_layers, batch_size, num_heads, context_length, d_kv
    ) -> torch.Tensor:
        """Run the encoder and store the decoder's cross-attention K/V in ``cross_key_values``.

        Args:
            inputs_embeds: Passed to the wrapped encoder as ``inputs_embeds``; dim 0 is the batch.
            cross_key_values: Destination cache. Its leading dim holds ``2 * decoder_layers``
                entries ordered ``[k_0, v_0, k_1, v_1, ...]``.

        Returns:
            The tensor returned by the ``rbln_cache_update`` custom op.
        """
        # 1. get encoder last_hidden_states
        encoder_outputs = self.encoder(inputs_embeds=inputs_embeds, attention_mask=None, return_dict=False)
        last_hidden_states = encoder_outputs[0]

        # 2. Pre-compute every decoder layer's cross-attention K/V for the decoder phase.
        #    Each is (batch, num_heads, seq, d_kv), appended as [k_0, v_0, k_1, v_1, ...]
        #    to match how the decoder wrapper unpacks cross_kv_cache[i], cross_kv_cache[i + 1].
        cross_kv = []
        batch_size = inputs_embeds.shape[0]
        for k_proj, v_proj in zip(self.cross_k_projects, self.cross_v_projects):
            past_k = k_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
            past_v = v_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)

            cross_kv.append(past_k)
            cross_kv.append(past_v)

        cross_kv = torch.stack(cross_kv, dim=0)

        # 3. Write the stacked K/V into the cache tensor via the RBLN custom op
        #    (start index 0 along axis 1).
        bidx = torch.tensor(0, dtype=torch.int16)
        axis = torch.tensor(1, dtype=torch.int16)
        enc_output = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)

        return enc_output
|
84
|
+
|
85
|
+
|
86
|
+
class TimeSeriesTransformersDecoderWrapper(torch.nn.Module):
    """Single-step decoder module with paged self-attention and pre-computed cross-attention.

    Every decoder layer of ``model`` is rebuilt with ``TimeSeriesTransformersSelfAttention`` and
    ``TimeSeriesTransformersCrossAttention``. The model's ``parameter_projection`` head is then
    applied to the last position of the decoder output.
    """

    def __init__(self, model, num_parallel_samples):
        super().__init__()
        self.config = model.config
        self.num_layers = self.config.decoder_layers
        self.decoder = self.convert_to_rbln_tst_decoder(model, num_parallel_samples)
        self.parameter_projection = model.parameter_projection

    def convert_to_rbln_tst_decoder(self, model: nn.Module, num_parallel_samples: int):
        """Wrap every layer of ``model``'s decoder with the RBLN self-/cross-attention modules.

        Returns:
            A ``TimeSeriesTransformersDecoder`` that shares all weights with ``model``.
        """
        original_decoder = model.get_decoder()
        new_layers = []
        for layer in original_decoder.layers:
            self_attn = TimeSeriesTransformersSelfAttention(layer.self_attn, num_parallel_samples)
            cross_attn = TimeSeriesTransformersCrossAttention(layer.encoder_attn, num_parallel_samples)
            new_layers.append(TimeSeriesTransformersDecoderLayer(layer, self_attn, cross_attn))

        return TimeSeriesTransformersDecoder(original_decoder, new_layers)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        decoder_attention_mask: torch.Tensor,
        cache_position: torch.Tensor,
        block_tables: torch.Tensor,
        cross_kv_cache: torch.Tensor,  # 2 * n_layers, batch_size, num_heads, context_length, d_kv
        *self_kv_cache: torch.Tensor,  # 2 * n_layers tensors of (batch_size * num_parallel_samples, num_heads, prediction_length, d_kv)
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, ...]], torch.Tensor]:
        """Decode one step and project the last position to distribution parameters.

        Args:
            inputs_embeds: Decoder inputs, fed to the decoder's ``value_embedding``.
            decoder_attention_mask: Self-attention mask for the decoder.
            cache_position: Current decoding position(s).
            block_tables: Block table forwarded to the paged self-attention op.
            cross_kv_cache: Stacked cross-attention K/V ordered ``[k_0, v_0, k_1, v_1, ...]``.
            *self_kv_cache: Self-attention caches in the same ``[k_0, v_0, ...]`` order.

        Returns:
            ``(params, last_hidden_states)``. ``params`` is the output of ``parameter_projection``
            on the last position (NOTE(review): presumably a tuple of distribution-parameter
            tensors — confirm).
        """
        # Regroup the flat caches into per-layer (key, value) pairs.
        layer_offsets = range(0, self.num_layers * 2, 2)
        self_past_key_values = tuple((self_kv_cache[i], self_kv_cache[i + 1]) for i in layer_offsets)
        cross_past_key_values = tuple((cross_kv_cache[i], cross_kv_cache[i + 1]) for i in layer_offsets)

        # Decode
        last_hidden_states = self.decoder(
            inputs_embeds=inputs_embeds,
            attention_mask=decoder_attention_mask,
            cache_position=cache_position,
            block_tables=block_tables,
            self_past_key_values=self_past_key_values,
            cross_past_key_values=cross_past_key_values,
        )

        # Only the last position is projected (slice keeps the sequence dim).
        params = self.parameter_projection(last_hidden_states[:, -1:])

        return (params, last_hidden_states)
|
138
|
+
|
139
|
+
|
140
|
+
class TimeSeriesTransformersDecoder(nn.Module):
    """Decoder stack for one decoding step, reusing an existing decoder's embeddings.

    The value embedding, position embedding table and embedding layer norm are taken from
    ``model``; ``layers`` (the RBLN-wrapped decoder layers) replace the original layers.
    """

    def __init__(self, model, layers, **kwargs):
        # ``kwargs`` is accepted for call-site compatibility and ignored.
        super().__init__()
        self._original_mod = model
        self.config = model.config
        self.layers = nn.ModuleList(layers)
        self.value_embedding = model.value_embedding
        self.embed_positions = model.embed_positions
        self.layernorm_embedding = model.layernorm_embedding

    def forward(
        self,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        self_past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        cross_past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        cache_position: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``inputs_embeds`` and run it through every layer.

        Args:
            inputs_embeds: Decoder inputs; every dim except the last is treated as the input shape.
            attention_mask: Mask expanded to 4-D before being given to the layers.
            self_past_key_values: One ``(key_cache, value_cache)`` pair per layer.
            cross_past_key_values: One ``(key, value)`` pair of pre-computed encoder K/V per layer.
            cache_position: Current decoding position(s); also indexes the position embeddings.
            block_tables: Forwarded to each layer's paged self-attention.

        Returns:
            The hidden states produced by the last layer.
        """
        input_shape = inputs_embeds.size()[:-1]

        # Prepare the causal attention mask.
        # NOTE(review): cache_position (a tensor) is passed in the helper's 4th positional slot,
        # which upstream documents as an integer past-KV length; presumably intentional for
        # tracing the decode step — confirm.
        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)

        hidden_states = self.value_embedding(inputs_embeds)
        # Position-embedding lookups are offset by context_length (presumably because decoder
        # steps follow the encoder's context window).
        embed_pos = self.embed_positions.weight[cache_position + self.config.context_length]
        hidden_states = self.layernorm_embedding(hidden_states + embed_pos)

        # iterate decoder_layer
        for self_past_key_value, cross_past_key_value, decoder_layer in zip(
            self_past_key_values, cross_past_key_values, self.layers
        ):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                self_past_key_value=self_past_key_value,
                cross_past_key_value=cross_past_key_value,
                cache_position=cache_position,
                block_tables=block_tables,
            )

        return hidden_states
|
182
|
+
|
183
|
+
|
184
|
+
class TimeSeriesTransformersDecoderLayer(nn.Module):
    """Post-norm decoder layer that swaps in the given self-/cross-attention modules.

    Layer norms, feed-forward projections and the activation are reused from
    ``decoder_layer``. Each sub-block computes ``x = norm(x + sublayer(x))``.
    """

    def __init__(self, decoder_layer, self_attn, cross_attn):
        super().__init__()
        self._original_mod = decoder_layer
        self.self_attn = self_attn
        self.encoder_attn = cross_attn
        self.embed_dim = decoder_layer.embed_dim
        self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
        self.encoder_attn_layer_norm = decoder_layer.encoder_attn_layer_norm
        self.final_layer_norm = decoder_layer.final_layer_norm
        self.activation_fn = decoder_layer.activation_fn
        self.fc1 = decoder_layer.fc1
        self.fc2 = decoder_layer.fc2

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        self_past_key_value: Optional[Tuple[torch.Tensor]] = None,
        cross_past_key_value: Optional[Tuple[torch.Tensor]] = None,
        cache_position: Optional[torch.Tensor] = None,
        block_tables: torch.Tensor = None,
    ) -> torch.Tensor:
        """Apply self-attention, cross-attention and the feed-forward block.

        Args:
            hidden_states: Input to the layer.
            attention_mask: Mask for self-attention only.
            self_past_key_value: This layer's self-attention cache pair.
            cross_past_key_value: This layer's pre-computed encoder K/V pair.
            cache_position: Current decoding position(s), forwarded to self-attention.
            block_tables: Forwarded to self-attention.

        Returns:
            The layer output, same leading shape as ``hidden_states``.
        """
        # Self Attention Block
        residual = hidden_states
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_past_key_value,
            attention_mask=attention_mask,
            cache_position=cache_position,
            block_tables=block_tables,
        )
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block. No mask is passed: every query attends to all cached
        # encoder positions.
        residual = hidden_states
        hidden_states = self.encoder_attn(
            hidden_states=hidden_states,
            past_key_value=cross_past_key_value,
        )
        hidden_states = residual + hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected Block
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states
|
237
|
+
|
238
|
+
|
239
|
+
class TimeSeriesTransformersAttention(nn.Module):
    """Base for the RBLN attention wrappers: borrows projections and head geometry.

    The ``q/k/v/out`` projection modules and the ``num_heads``/``embed_dim``/``head_dim``/
    ``scaling`` values are taken from ``attn`` as-is (shared, not copied). Subclasses
    provide ``forward``.
    """

    def __init__(self, attn, num_parallel_samples):
        super().__init__()
        self._original_mod = attn
        for proj_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
            setattr(self, proj_name, getattr(attn, proj_name))
        for geometry_name in ("num_heads", "embed_dim", "head_dim", "scaling"):
            setattr(self, geometry_name, getattr(attn, geometry_name))
        self.num_parallel_samples = num_parallel_samples

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        """Split heads: ``(bsz, seq_len, embed)`` -> ``(bsz, num_heads, seq_len, head_dim)``."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.permute(0, 2, 1, 3)
|
255
|
+
|
256
|
+
|
257
|
+
class TimeSeriesTransformersSelfAttention(TimeSeriesTransformersAttention):
    """Decoder self-attention for a single decoding step, backed by a paged KV cache.

    Queries, keys and values are laid out as 5-D tensors and handed to the
    ``paged_add_softmax_attn_decode`` RBLN custom op together with this layer's caches.
    """

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        """Fold batch and heads together: -> ``(1, bsz * num_heads, 1, seq_len, head_dim)``.

        NOTE(review): this ``view`` only keeps each head's values contiguous when
        ``seq_len == 1`` or ``bsz == 1`` (the single-step decode case). It is not a general
        head split for longer sequences.
        """
        return tensor.view(1, seq_len, 1, bsz * self.num_heads, self.head_dim).transpose(1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend from the current step to the paged self-attention cache.

        Args:
            hidden_states: ``(bsz, tgt_len, embed_dim)``. ``bsz`` presumably counts every
                parallel sample (batch_size * num_parallel_samples).
            past_key_value: ``(key_cache, value_cache)``. Both are viewed as
                ``(1, bsz * num_heads, 1, -1, head_dim)``; the paged block size is taken from
                the key cache's second-to-last dim.
            attention_mask: Mask given to the op with an extra dim inserted at position 2.
            cache_position: Current position; expanded to ``(bsz, 1)``.
            block_tables: Block table for the paged cache.

        Returns:
            The attention output after ``out_proj``, reshaped to ``(bsz, tgt_len, embed_dim)``
            before projection.
        """
        bsz, tgt_len, _ = hidden_states.size()
        # Queries are pre-scaled here, so the custom op below is given scale=1.0.
        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
        query_states = query_states * self.scaling

        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        block_size = past_key_value[0].shape[-2]
        attn_output = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode(
            q=query_states,
            k=key_states,
            v=value_states,
            mask=attention_mask.unsqueeze(2),
            kcache=past_key_value[0].view(1, bsz * self.num_heads, 1, -1, self.head_dim),
            vcache=past_key_value[1].view(1, bsz * self.num_heads, 1, -1, self.head_dim),
            seq=cache_position.expand(bsz, 1),
            scale=torch.tensor(1.0, dtype=torch.float32),  # already applied to query_states
            block_table=block_tables,
            block_size=block_size,
        )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output
|
296
|
+
|
297
|
+
|
298
|
+
class TimeSeriesTransformersCrossAttention(TimeSeriesTransformersSelfAttention):
    """Cross-attention against pre-computed encoder K/V shared across parallel samples.

    Queries are regrouped as ``(batch_size // num_parallel_samples, num_parallel_samples, ...)``.
    The cached keys/values get a singleton sample dim, so one encoder K/V entry broadcasts
    over all samples. ``forward`` is overridden entirely; the inherited 5-D ``_shape`` is
    not used here.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Attend from decoder queries to the cached encoder keys/values.

        Args:
            hidden_states: ``(batch_size, query_len, embed_dim)``; ``batch_size`` must be a
                multiple of ``num_parallel_samples``.
            past_key_value: ``(key, value)``, each expected as
                ``(batch_size // num_parallel_samples, num_heads, kv_len, head_dim)``.

        Returns:
            ``out_proj`` applied to the ``(batch_size, query_len, embed_dim)`` attention output.
        """
        batch_size, query_len, _ = hidden_states.size()
        num_groups = batch_size // self.num_parallel_samples

        # -> (num_groups, num_parallel_samples, num_heads, query_len, head_dim)
        query_states = (
            self.q_proj(hidden_states)
            .view(num_groups, self.num_parallel_samples, query_len, self.num_heads, self.head_dim)
            .transpose(2, 3)
        )
        query_states = query_states * self.scaling

        # Insert a singleton sample dim so the shared encoder K/V broadcast over samples.
        key_states = past_key_value[0].unsqueeze(1)
        value_states = past_key_value[1].unsqueeze(1)

        # No mask: every query attends to all cached encoder positions.
        attn_weights = torch.matmul(query_states, key_states.transpose(3, 4))
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.view(batch_size, self.num_heads, query_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output
|
@@ -0,0 +1,19 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from ...configuration_generic import RBLNModelForMaskedLMConfig
|
16
|
+
|
17
|
+
|
18
|
+
class RBLNWav2Vec2ForCTCConfig(RBLNModelForMaskedLMConfig):
    """RBLN configuration for ``RBLNWav2Vec2ForCTC``.

    Reuses ``RBLNModelForMaskedLMConfig`` unchanged except for the model input names.
    """

    # The compiled model takes a single ``input_values`` tensor (the model's main input).
    rbln_model_input_names = ["input_values"]
|
@@ -12,26 +12,13 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from typing import TYPE_CHECKING, Any, Dict, Union
|
16
15
|
|
17
16
|
import torch
|
18
|
-
from transformers import AutoModelForMaskedLM,
|
17
|
+
from transformers import AutoModelForMaskedLM, Wav2Vec2ForCTC
|
19
18
|
from transformers.modeling_outputs import CausalLMOutput
|
20
19
|
|
21
|
-
from
|
22
|
-
from
|
23
|
-
from ....utils.logging import get_logger
|
24
|
-
|
25
|
-
|
26
|
-
logger = get_logger(__name__)
|
27
|
-
|
28
|
-
if TYPE_CHECKING:
|
29
|
-
from transformers import (
|
30
|
-
AutoFeatureExtractor,
|
31
|
-
AutoProcessor,
|
32
|
-
AutoTokenizer,
|
33
|
-
PretrainedConfig,
|
34
|
-
)
|
20
|
+
from ...modeling_generic import RBLNModelForMaskedLM
|
21
|
+
from .configuration_wav2vec import RBLNWav2Vec2ForCTCConfig
|
35
22
|
|
36
23
|
|
37
24
|
class _Wav2Vec2(torch.nn.Module):
|
@@ -44,11 +31,11 @@ class _Wav2Vec2(torch.nn.Module):
|
|
44
31
|
return self.model.lm_head(output[0])
|
45
32
|
|
46
33
|
|
47
|
-
class RBLNWav2Vec2ForCTC(
|
34
|
+
class RBLNWav2Vec2ForCTC(RBLNModelForMaskedLM):
|
48
35
|
"""
|
49
36
|
Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
|
50
37
|
|
51
|
-
This model inherits from [`
|
38
|
+
This model inherits from [`RBLNModelForMaskedLM`]. Check the superclass documentation for the generic methods the
|
52
39
|
library implements for all its model.
|
53
40
|
|
54
41
|
It implements the methods to convert a pre-trained Wav2Vec2 model into a RBLN Wav2Vec2 model by:
|
@@ -58,60 +45,10 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
|
|
58
45
|
|
59
46
|
main_input_name = "input_values"
|
60
47
|
auto_model_class = AutoModelForMaskedLM
|
48
|
+
rbln_dtype = "float32"
|
49
|
+
output_class = CausalLMOutput
|
50
|
+
output_key = "logits"
|
61
51
|
|
62
52
|
@classmethod
|
63
|
-
def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config:
|
53
|
+
def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNWav2Vec2ForCTCConfig) -> torch.nn.Module:
|
64
54
|
return _Wav2Vec2(model).eval()
|
65
|
-
|
66
|
-
@classmethod
|
67
|
-
def _get_rbln_config(
|
68
|
-
cls,
|
69
|
-
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
70
|
-
model_config: "PretrainedConfig",
|
71
|
-
rbln_kwargs: Dict[str, Any] = {},
|
72
|
-
) -> RBLNConfig:
|
73
|
-
rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
|
74
|
-
rbln_batch_size = rbln_kwargs.get("batch_size", None)
|
75
|
-
|
76
|
-
if rbln_max_seq_len is None:
|
77
|
-
for tokenizer in preprocessors:
|
78
|
-
if hasattr(tokenizer, "model_max_length"):
|
79
|
-
rbln_max_seq_len = tokenizer.model_max_length
|
80
|
-
break
|
81
|
-
if rbln_max_seq_len is None:
|
82
|
-
raise ValueError("`rbln_max_seq_len` should be specified!")
|
83
|
-
|
84
|
-
if rbln_batch_size is None:
|
85
|
-
rbln_batch_size = 1
|
86
|
-
|
87
|
-
input_info = [
|
88
|
-
(
|
89
|
-
"input_values",
|
90
|
-
[
|
91
|
-
rbln_batch_size,
|
92
|
-
rbln_max_seq_len,
|
93
|
-
],
|
94
|
-
"float32",
|
95
|
-
),
|
96
|
-
]
|
97
|
-
|
98
|
-
rbln_compile_config = RBLNCompileConfig(input_info=input_info)
|
99
|
-
|
100
|
-
rbln_config = RBLNConfig(
|
101
|
-
rbln_cls=cls.__name__,
|
102
|
-
compile_cfgs=[rbln_compile_config],
|
103
|
-
rbln_kwargs=rbln_kwargs,
|
104
|
-
)
|
105
|
-
|
106
|
-
rbln_config.model_cfg.update(
|
107
|
-
{
|
108
|
-
"max_seq_len": rbln_max_seq_len,
|
109
|
-
"batch_size": rbln_batch_size,
|
110
|
-
}
|
111
|
-
)
|
112
|
-
|
113
|
-
return rbln_config
|
114
|
-
|
115
|
-
def forward(self, input_values: "torch.Tensor", **kwargs):
|
116
|
-
outputs = super().forward(input_values, **kwargs)
|
117
|
-
return CausalLMOutput(logits=outputs)
|
@@ -12,4 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
from ....ops import paged_add_softmax_attn_decode
|
16
|
+
from .configuration_whisper import RBLNWhisperForConditionalGenerationConfig
|
15
17
|
from .modeling_whisper import RBLNWhisperForConditionalGeneration
|
@@ -0,0 +1,64 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import rebel
|
16
|
+
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
18
|
+
from ....utils.logging import get_logger
|
19
|
+
|
20
|
+
|
21
|
+
logger = get_logger()
|
22
|
+
|
23
|
+
|
24
|
+
class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
    """RBLN configuration for ``RBLNWhisperForConditionalGeneration``."""

    def __init__(
        self,
        batch_size: int = None,
        token_timestamps: bool = None,
        use_attention_mask: bool = None,
        enc_max_seq_len: int = None,
        dec_max_seq_len: int = None,
        **kwargs,
    ):
        """
        Args:
            batch_size (int, optional): The batch size for inference. Defaults to 1.
            token_timestamps (bool, optional): Whether to output token timestamps during generation. Defaults to False.
            use_attention_mask (bool, optional): Whether to use attention masks during inference. This is automatically
                set to True for RBLN-CA02 devices.
            enc_max_seq_len (int, optional): Maximum sequence length for the encoder.
            dec_max_seq_len (int, optional): Maximum sequence length for the decoder.
            **kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If batch_size is not a positive integer. Only ``None`` falls back to the
                default of 1. An explicit ``0``, a negative value, a bool or a non-int is rejected.
        """
        super().__init__(**kwargs)

        # `batch_size or 1` would silently turn an explicit 0 into 1, contradicting the
        # positive-integer contract, so only None selects the default.
        self.batch_size = 1 if batch_size is None else batch_size
        # bool is a subclass of int, so True/False must be excluded explicitly.
        if not isinstance(self.batch_size, int) or isinstance(self.batch_size, bool) or self.batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        self.token_timestamps = token_timestamps or False
        self.enc_max_seq_len = enc_max_seq_len
        self.dec_max_seq_len = dec_max_seq_len

        self.use_attention_mask = use_attention_mask
        # Target NPU: self.npu (presumably set by the parent config) wins; otherwise ask the runtime.
        npu = self.npu or rebel.get_npu_name()
        if npu == "RBLN-CA02":
            if self.use_attention_mask is False:
                logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
            self.use_attention_mask = True
        else:
            self.use_attention_mask = self.use_attention_mask or False
|