optimum-rbln 0.1.15__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- optimum/rbln/__init__.py +26 -33
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/__init__.py +4 -0
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
- optimum/rbln/diffusers/models/__init__.py +2 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
- optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
- optimum/rbln/diffusers/models/controlnet.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
- optimum/rbln/diffusers/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
- optimum/rbln/modeling.py +13 -347
- optimum/rbln/modeling_base.py +24 -4
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -0
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
- optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/utils/rbln_quantization.py +0 -1
- optimum/rbln/utils/decorator_utils.py +51 -15
- optimum/rbln/utils/import_utils.py +7 -0
- optimum/rbln/utils/logging.py +37 -0
- optimum/rbln/utils/model_utils.py +0 -1
- optimum/rbln/utils/runtime_utils.py +9 -3
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +23 -0
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/METADATA +37 -26
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/RECORD +76 -72
- optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +0 -0
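Several of the changes above route through the new optimum/rbln/ops package (attn.py, flash_attn.py, kv_cache_update.py), which registers custom PyTorch operators used by the compiled graphs; the Whisper diff below, for instance, calls torch.ops.rbln_custom_ops.rbln_cache_update. The sketch that follows is only an illustration of the registration mechanism, not the package's actual code: it defines a comparable cache-update op under a made-up demo_rbln_ops namespace with torch.library and backs it with an eager reference implementation.

```python
import torch
from torch.library import Library

# Hypothetical stand-in for the kind of operator optimum/rbln/ops registers.
# Namespace, schema, and the eager fallback below are illustrative assumptions,
# not the package's actual registration.
lib = Library("demo_rbln_ops", "DEF")
lib.define("cache_update(Tensor cache, Tensor state, Tensor position, Tensor axis) -> Tensor")


def cache_update_eager(cache, state, position, axis):
    # Eager reference: write `state` into `cache` at `position` along dim `axis`.
    dim, start = int(axis), int(position)
    return cache.slice_scatter(state, dim=dim, start=start, end=start + state.size(dim))


lib.impl("cache_update", cache_update_eager, "CompositeExplicitAutograd")

cache = torch.zeros(1, 4, 8, 16)   # e.g. (batch, heads, max_len, head_dim)
state = torch.ones(1, 4, 1, 16)    # one new key/value slice
updated = torch.ops.demo_rbln_ops.cache_update(
    cache, state, torch.tensor(3, dtype=torch.int16), torch.tensor(2, dtype=torch.int16)
)
print(updated[0, 0, 2:5, 0])  # tensor([0., 1., 0.])
```

Registering the update as a single named op keeps it visible as one node in the traced graph, which is what lets a backend compiler replace it with an in-place device write instead of the functional slice_scatter used for eager execution.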
optimum/rbln/transformers/models/whisper/whisper_architecture.py

@@ -27,401 +27,308 @@ import torch
 from torch import nn
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
 )
 from transformers.modeling_outputs import (
     BaseModelOutput,
-    BaseModelOutputWithPastAndCrossAttentions,
     Seq2SeqLMOutput,
 )
-from transformers.models.whisper.modeling_whisper import (
-    WhisperAttention,
-    WhisperDecoder,
-    WhisperDecoderLayer,
-    WhisperPositionalEmbedding,
-    WhisperSdpaAttention,
-)
 from transformers.utils import logging
 
+from ....ops import register_rbln_custom_cache_update
 
-logger = logging.get_logger(__name__)
 
+logger = logging.get_logger(__name__)
 
-class _WhisperAttention(WhisperAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, tgt_len, _ = hidden_states.size()
-        is_cross_attention = key_value_states is not None
-
-        query_states = self.q_proj(hidden_states) * self.scaling
-
-        if is_cross_attention:
-            is_dummy_decoder = len(key_value_states.shape) > 1
-            if is_dummy_decoder:
-                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-            else:
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-        else:
-            if self.is_decoder:
-                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-                key_states = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                value_states = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-            else:
-                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        if self.is_decoder:
-            present_key_value = (key_states, value_states)
-        else:
-            present_key_value = None
-
-        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.reshape(*proj_shape)
-        value_states = value_states.reshape(*proj_shape)
-
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-        src_len = key_states.size(1)
-        if attention_mask is not None:
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
-
+class WhisperWrapper:
+    def __init__(self, model, rbln_token_timestamps):
+        register_rbln_custom_cache_update()
+        self.encoder = WhisperEncoderWrapper(model)
+        self.decoder = WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps)
 
-        attn_output = torch.bmm(attn_weights, value_states)
 
-
-
+class WhisperEncoderWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.config = model.config
+        self.encoder = model.get_encoder()
+        self.num_heads = self.config.decoder_attention_heads
+        self.d_kv = self.config.d_model // self.num_heads
+        self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().layers)
 
-
-
+    def _extract_cross_kv_projects(self, decoder_layers: nn.Module):
+        return (
+            nn.ModuleList(layer.encoder_attn.k_proj for layer in decoder_layers),
+            nn.ModuleList(layer.encoder_attn.v_proj for layer in decoder_layers),
+        )
 
-
+    def forward(
+        self,
+        input_features: Optional[torch.LongTensor],
+        cross_key_values: torch.Tensor,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+        # 1. get encoder last_hidden_states
+        encoder_outputs = self.encoder(input_features=input_features)
+        last_hidden_states = encoder_outputs[0]
 
-
+        # 2. pre-compute cross_attention's past_key_value which used in decoder phase.
+        cross_kv = []
+        batch_size = input_features.shape[0]
+        for k_proj, v_proj in zip(self.cross_k_projects, self.cross_v_projects):
+            past_k = k_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
+            past_v = v_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
 
+            cross_kv.append(past_k)
+            cross_kv.append(past_v)
 
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, tgt_len, _ = hidden_states.size()
+        cross_kv = torch.stack(cross_kv, dim=0)
 
-
-
-
-
-        if is_cross_attention:
-            is_dummy_decoder = len(key_value_states.shape) > 1
-            if is_dummy_decoder:
-                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-            else:
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-        else:
-            if self.is_decoder:
-                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-                key_states = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                value_states = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-            else:
-                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        if self.is_decoder:
-            present_key_value = (key_states, value_states)
-        else:
-            present_key_value = None
-
-        query_states = self._shape(query_states, tgt_len, bsz)
-
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            dropout_p=0.0,
-            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
-        )
+        # 3. update cross_attention's past_key_value to the device-dram for optimization.
+        bidx = torch.tensor(0, dtype=torch.int16)
+        axis = torch.tensor(1, dtype=torch.int16)
+        cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)
 
-
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        return cross_key_values
 
-        attn_output = self.out_proj(attn_output)
 
-
+class WhisperDecoderWrapper(torch.nn.Module):
+    def __init__(self, model, output_attentions: bool = False):
+        super().__init__()
+        self.config = model.config
+        self.num_layers = self.config.decoder_layers
+        self.proj_out = model.proj_out
+        self.decoder = self.convert_to_rbln_conditional_generation(model)
+        self.output_attentions = output_attentions
 
+    def convert_to_rbln_conditional_generation(self, model: nn.Module):
+        new_layers = []
+        for layer in model.get_decoder().layers:
+            self_attn = WhisperSelfAttention(layer.self_attn)
+            cross_attn = WhisperCrossAttention(layer.encoder_attn)
+            new_layers.append(WhisperDecoderLayer(layer, self_attn, cross_attn))
 
-
+        decoder_model = WhisperDecoder(model.get_decoder(), new_layers)
 
+        return decoder_model
 
-class _WhisperDecoderLayer(WhisperDecoderLayer):
     def forward(
         self,
-
-
-
-
-
-
-
-
-
-
-
-
+        decoder_input_ids: torch.Tensor,
+        decoder_attention_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        cross_kv_cache: torch.Tensor,
+        *self_kv_cache: torch.Tensor,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+        # prepare past_key_values
+        self_past_key_values = ()
+        cross_past_key_values = ()
+        for i in range(0, self.num_layers * 2, 2):
+            self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
+            cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
 
-
-
-
-
-            attention_mask=attention_mask,
+        # Decode
+        sequence_output, self_present_key_values, cross_attentions = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
             cache_position=cache_position,
+            self_past_key_values=self_past_key_values,
+            cross_past_key_values=cross_past_key_values,
         )
-        hidden_states = residual + hidden_states
 
-
-        residual = hidden_states
-        hidden_states = self.encoder_attn_layer_norm(hidden_states)
-        cross_attn_past_key_value = past_key_value[2:] if past_key_value is not None else None
-        if output_attentions:
-            hidden_states, cross_attn_weights, cross_attn_present_key_value = _WhisperAttention.forward(
-                self.encoder_attn,
-                hidden_states=hidden_states,
-                key_value_states=encoder_hidden_states,
-                past_key_value=cross_attn_past_key_value,
-                cache_position=cache_position,
-            )
-        else:
-            hidden_states, cross_attn_weights, cross_attn_present_key_value = ATTN_FORWARD_MAP[attn_impl](
-                self.encoder_attn,
-                hidden_states=hidden_states,
-                key_value_states=encoder_hidden_states,
-                past_key_value=cross_attn_past_key_value,
-                cache_position=cache_position,
-            )
-        hidden_states = residual + hidden_states
-        present_key_value = present_key_value + cross_attn_present_key_value
+        lm_logits = self.proj_out(sequence_output)
 
-
-
-        hidden_states = self.final_layer_norm(hidden_states)
-        hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = residual + hidden_states
+        outputs = (lm_logits,)
+        outputs += self_present_key_values
 
-
+        if self.output_attentions:
+            # deocder's cross attention is used for token_timestamps
+            cross_attention = torch.stack(cross_attentions, dim=0)
+            outputs += (cross_attention,)
 
+        return outputs
 
-class _WhisperPositionalEmbedding(WhisperPositionalEmbedding):
-    def forward(self, input_ids, past_key_values_length=0, position_ids=None):
-        if position_ids is None:
-            return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
-        else:
-            return self.weight[position_ids]
 
+class WhisperDecoder(nn.Module):
+    def __init__(self, model, layers, **kwargs):
+        super().__init__()
+        self._original_mod = model
+        self.layers = nn.ModuleList(layers)
+        self.embed_tokens = model.embed_tokens
+        self.layer_norm = model.layer_norm
+        self.embed_positions = model.embed_positions
 
-class _WhisperDecoder(WhisperDecoder):
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-
-
+        self_past_key_values: Optional[torch.Tensor] = None,
+        cross_past_key_values: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
-        attn_impl: str = "eager",
-        output_attentions: bool = False,
-        **kwargs,
     ):
         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_shape[-1])
 
         # positional embeding
         inputs_embeds = self.embed_tokens(input_ids)
-        positions =
-            self.embed_positions, input_ids, cache_position, cache_position
-        )
+        positions = self.embed_positions(input_ids, position_ids=cache_position)
         hidden_states = inputs_embeds + positions
 
         # prepare casual_attn_mask
-
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask, input_shape, inputs_embeds, cache_position
-            )
-        else:
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, input_shape, inputs_embeds, cache_position
-            )
+        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)
 
-
-
+        self_present_key_values = ()
+        cross_attentions = ()
         # iterate decoder_layer
-        for
-
-
-
+        for self_past_key_value, cross_past_key_value, decoder_layer in zip(
+            self_past_key_values, cross_past_key_values, self.layers
+        ):
+            layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
-
-
+                self_past_key_value=self_past_key_value,
+                cross_past_key_value=cross_past_key_value,
                 cache_position=cache_position,
-                attn_impl=attn_impl,
-                output_attentions=output_attentions,
             )
             hidden_states = layer_outputs[0]
+            self_present_key_values += layer_outputs[1]
+            cross_attentions += (layer_outputs[2],)
 
-            next_decoder_cache += (layer_outputs[1],)
-            if output_attentions:
-                all_cross_attentions += (layer_outputs[2],)
-
-        # layer_norm
         hidden_states = self.layer_norm(hidden_states)
 
-        return
-            last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
-            cross_attentions=all_cross_attentions,
-        )
+        return hidden_states, self_present_key_values, cross_attentions
 
 
-class
-    def __init__(self,
+class WhisperDecoderLayer(nn.Module):
+    def __init__(self, decoder_layer, self_attn, cross_attn):
         super().__init__()
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
+        self._original_mod = decoder_layer
+        self.self_attn = self_attn
+        self.encoder_attn = cross_attn
+        self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
+        self.encoder_attn_layer_norm = decoder_layer.encoder_attn_layer_norm
+        self.final_layer_norm = decoder_layer.final_layer_norm
+        self.activation_fn = decoder_layer.activation_fn
+        self.fc1 = decoder_layer.fc1
+        self.fc2 = decoder_layer.fc2
 
     def forward(
         self,
-
-
-
-
-
-    ) ->
-        #
-
-
-
-
-
-
-                cross_kv_cache[i],
-                cross_kv_cache[i + 1],
-            ),
-        )
-
-        # Decode
-        decoder_outputs = _WhisperDecoder.forward(
-            self.decoder,
-            input_ids=decoder_input_ids,
-            attention_mask=decoder_attention_mask,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        self_past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        cross_past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # Self Attention Block
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states, _, self_present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_past_key_value,
+            attention_mask=attention_mask,
             cache_position=cache_position,
-            past_key_values=kv_cache,
-            encoder_hidden_states=torch.tensor([1]),
-            attn_impl=self.attn_impl,
-            output_attentions=self.output_attentions,
         )
-
-        lm_logits = self.proj_out(sequence_output)
+        hidden_states = residual + hidden_states
 
-        #
-
-
-
-
-
-
+        # Cross-Attention Block
+        residual = hidden_states
+        hidden_states = self.encoder_attn_layer_norm(hidden_states)
+        hidden_states, cross_attn_weights, cross_present_key_value = self.encoder_attn(
+            hidden_states=hidden_states,
+            past_key_value=cross_past_key_value,
+        )
+        hidden_states = residual + hidden_states
 
-
-
-
-
-
-
+        # Fully Connected Block
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = residual + hidden_states
 
+        return hidden_states, self_present_key_value, cross_attn_weights
 
-
-
+
+class WhisperAttention(nn.Module):
+    def __init__(self, attn):
         super().__init__()
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-
+        self._original_mod = attn
+        self.q_proj = attn.q_proj
+        self.k_proj = attn.k_proj
+        self.v_proj = attn.v_proj
+        self.out_proj = attn.out_proj
+        self.num_heads = attn.num_heads
+        self.embed_dim = attn.embed_dim
+        self.head_dim = attn.embed_dim // attn.num_heads
+        self.scaling = self.head_dim**-0.5
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+
+class WhisperSelfAttention(WhisperAttention):
+    def rbln_cache_update(
+        self,
+        past_key_value: torch.Tensor,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        cache_position: torch.Tensor,
+    ):
+        s_idx = torch.tensor(cache_position, dtype=torch.int16)
+        axis = torch.tensor(2, dtype=torch.int16)
+
+        key_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[0], key_states, s_idx, axis)
+        value_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[1], value_states, s_idx, axis)
+        return key_states, value_states
 
     def forward(
         self,
-
-
-
+        hidden_states: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, tgt_len, _ = hidden_states.size()
+        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
+        query_states = query_states * self.scaling
 
-
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+        key_states, value_states = self.rbln_cache_update(past_key_value, key_states, value_states, cache_position)
 
-
-
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+        attn_weights = attn_weights + attention_mask
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
-
-
-
-
-
-        pkv_cross_attn_value = torch.zeros(encoder_batch_size, self.num_heads, self.encoder_max_length, self.d_kv)
-        layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
-        dummy_past_key_value.append(layer_pkv)
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
 
-
-        decoder_attention_mask[:, :1] = 1
+        return attn_output, attn_weights, (key_states, value_states)
 
-        decoder_outputs = _WhisperDecoder.forward(
-            self.decoder,
-            input_ids=torch.zeros((decoder_batch_size, 1), dtype=torch.int64),
-            attention_mask=decoder_attention_mask,
-            cache_position=torch.tensor(0, dtype=torch.int32),
-            encoder_hidden_states=last_hidden_states,
-            past_key_values=dummy_past_key_value,
-            attn_impl=self.attn_impl,
-            output_attentions=False,
-        )
 
-
+class WhisperCrossAttention(WhisperSelfAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        batch_size, query_len, _ = hidden_states.size()
+        query_states = self._shape(self.q_proj(hidden_states), query_len, batch_size)
+        query_states = query_states * self.scaling
 
-
-
-
-
-
+        key_states = past_key_value[0]
+        value_states = past_key_value[1]
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.view(batch_size, self.num_heads, query_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
 
-        return
+        return attn_output, attn_weights, (key_states, value_states)

(Several removed lines in this hunk did not survive the original rendering and are shown as bare "-" markers or truncated fragments.)
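The refactor shown above drops the subclass-override helpers (_WhisperAttention, _WhisperDecoderLayer, _WhisperDecoder, _WhisperPositionalEmbedding) in favor of standalone wrapper modules, and WhisperEncoderWrapper now pre-computes every decoder layer's cross-attention key/value projections once per encoded audio segment so that WhisperCrossAttention only reads them back while decoding. The snippet below is a minimal, torch-only sketch of that pre-computation pattern; it mirrors the reshape in WhisperEncoderWrapper.forward but leaves out the RBLN-specific rbln_cache_update device write, and the dimensions and layer stand-ins are illustrative assumptions rather than optimum-rbln API.

```python
import torch
from torch import nn

# Illustrative dimensions only (not taken from optimum-rbln).
batch, src_len, d_model, num_heads = 1, 1500, 384, 6
d_kv = d_model // num_heads

# Stand-ins for one decoder layer's cross-attention projections.
k_proj = nn.Linear(d_model, d_model)
v_proj = nn.Linear(d_model, d_model)

# Pretend this came from the Whisper encoder.
last_hidden_states = torch.randn(batch, src_len, d_model)

# Same reshape as WhisperEncoderWrapper.forward:
# (batch, src_len, d_model) -> (batch, num_heads, src_len, d_kv)
past_k = k_proj(last_hidden_states).view(batch, -1, num_heads, d_kv).transpose(1, 2)
past_v = v_proj(last_hidden_states).view(batch, -1, num_heads, d_kv).transpose(1, 2)

# The real wrapper stacks one (k, v) pair per decoder layer and writes the result
# into a preallocated cross_key_values buffer via rbln_cache_update; here we just
# stack to show the layout the decoder consumes.
cross_kv = torch.stack([past_k, past_v], dim=0)
print(cross_kv.shape)  # torch.Size([2, 1, 6, 1500, 64])
```

Pre-computing the projections once per segment keeps the per-token decoder step down to cache reads plus the query projection, which matches how WhisperDecoderWrapper.forward receives cross_kv_cache as a plain input tensor rather than re-projecting encoder states at every step.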
|