optimum-rbln 0.1.15__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. optimum/rbln/__init__.py +26 -33
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +4 -0
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
  5. optimum/rbln/diffusers/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
  8. optimum/rbln/diffusers/models/controlnet.py +1 -1
  9. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
  10. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
  11. optimum/rbln/diffusers/pipelines/__init__.py +1 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
  13. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
  17. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
  21. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
  27. optimum/rbln/modeling.py +13 -347
  28. optimum/rbln/modeling_base.py +24 -4
  29. optimum/rbln/modeling_config.py +31 -7
  30. optimum/rbln/ops/__init__.py +26 -0
  31. optimum/rbln/ops/attn.py +221 -0
  32. optimum/rbln/ops/flash_attn.py +70 -0
  33. optimum/rbln/ops/kv_cache_update.py +69 -0
  34. optimum/rbln/transformers/__init__.py +20 -0
  35. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  36. optimum/rbln/transformers/modeling_generic.py +385 -0
  37. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
  39. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  40. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
  42. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  43. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
  44. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
  45. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
  46. optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
  47. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  48. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
  49. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  51. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
  52. optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
  53. optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
  54. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  55. optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
  56. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  57. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
  58. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  59. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  60. optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
  61. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  62. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  63. optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
  64. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  65. optimum/rbln/transformers/utils/rbln_quantization.py +0 -1
  66. optimum/rbln/utils/decorator_utils.py +51 -15
  67. optimum/rbln/utils/import_utils.py +7 -0
  68. optimum/rbln/utils/logging.py +37 -0
  69. optimum/rbln/utils/model_utils.py +0 -1
  70. optimum/rbln/utils/runtime_utils.py +9 -3
  71. optimum/rbln/utils/save_utils.py +17 -0
  72. optimum/rbln/utils/submodule.py +23 -0
  73. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/METADATA +37 -26
  74. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/RECORD +76 -72
  75. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
  76. optimum/rbln/transformers/cache_utils.py +0 -107
  77. optimum/rbln/utils/timer_utils.py +0 -43
  78. optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
  79. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +0 -0
@@ -27,401 +27,308 @@ import torch
 from torch import nn
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
 )
 from transformers.modeling_outputs import (
     BaseModelOutput,
-    BaseModelOutputWithPastAndCrossAttentions,
     Seq2SeqLMOutput,
 )
-from transformers.models.whisper.modeling_whisper import (
-    WhisperAttention,
-    WhisperDecoder,
-    WhisperDecoderLayer,
-    WhisperPositionalEmbedding,
-    WhisperSdpaAttention,
-)
 from transformers.utils import logging

+from ....ops import register_rbln_custom_cache_update

-logger = logging.get_logger(__name__)

+logger = logging.get_logger(__name__)

-class _WhisperAttention(WhisperAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, tgt_len, _ = hidden_states.size()
-        is_cross_attention = key_value_states is not None
-
-        query_states = self.q_proj(hidden_states) * self.scaling
-
-        if is_cross_attention:
-            is_dummy_decoder = len(key_value_states.shape) > 1
-            if is_dummy_decoder:
-                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-            else:
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-        else:
-            if self.is_decoder:
-                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-                key_states = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                value_states = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-            else:
-                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        if self.is_decoder:
-            present_key_value = (key_states, value_states)
-        else:
-            present_key_value = None
-
-        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.reshape(*proj_shape)
-        value_states = value_states.reshape(*proj_shape)
-
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-        src_len = key_states.size(1)
-        if attention_mask is not None:
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+class WhisperWrapper:
+    def __init__(self, model, rbln_token_timestamps):
+        register_rbln_custom_cache_update()
+        self.encoder = WhisperEncoderWrapper(model)
+        self.decoder = WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps)

-        attn_output = torch.bmm(attn_weights, value_states)

-        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-        attn_output = attn_output.transpose(1, 2)
+class WhisperEncoderWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.config = model.config
+        self.encoder = model.get_encoder()
+        self.num_heads = self.config.decoder_attention_heads
+        self.d_kv = self.config.d_model // self.num_heads
+        self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().layers)

-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-        attn_output = self.out_proj(attn_output)
+    def _extract_cross_kv_projects(self, decoder_layers: nn.Module):
+        return (
+            nn.ModuleList(layer.encoder_attn.k_proj for layer in decoder_layers),
+            nn.ModuleList(layer.encoder_attn.v_proj for layer in decoder_layers),
+        )

-        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+    def forward(
+        self,
+        input_features: Optional[torch.LongTensor],
+        cross_key_values: torch.Tensor,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+        # 1. get encoder last_hidden_states
+        encoder_outputs = self.encoder(input_features=input_features)
+        last_hidden_states = encoder_outputs[0]

-        return attn_output, attn_weights, present_key_value
+        # 2. pre-compute cross_attention's past_key_value which used in decoder phase.
+        cross_kv = []
+        batch_size = input_features.shape[0]
+        for k_proj, v_proj in zip(self.cross_k_projects, self.cross_v_projects):
+            past_k = k_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
+            past_v = v_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)

+            cross_kv.append(past_k)
+            cross_kv.append(past_v)

-class _WhisperSdpaAttention(WhisperSdpaAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, tgt_len, _ = hidden_states.size()
+        cross_kv = torch.stack(cross_kv, dim=0)

-        is_cross_attention = key_value_states is not None
-
-        query_states = self.q_proj(hidden_states)
-
-        if is_cross_attention:
-            is_dummy_decoder = len(key_value_states.shape) > 1
-            if is_dummy_decoder:
-                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-            else:
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-        else:
-            if self.is_decoder:
-                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-                key_states = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                value_states = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-            else:
-                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        if self.is_decoder:
-            present_key_value = (key_states, value_states)
-        else:
-            present_key_value = None
-
-        query_states = self._shape(query_states, tgt_len, bsz)
-
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            dropout_p=0.0,
-            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
-        )
+        # 3. update cross_attention's past_key_value to the device-dram for optimization.
+        bidx = torch.tensor(0, dtype=torch.int16)
+        axis = torch.tensor(1, dtype=torch.int16)
+        cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)

-        attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        return cross_key_values

-        attn_output = self.out_proj(attn_output)

-        return attn_output, None, present_key_value
+class WhisperDecoderWrapper(torch.nn.Module):
+    def __init__(self, model, output_attentions: bool = False):
+        super().__init__()
+        self.config = model.config
+        self.num_layers = self.config.decoder_layers
+        self.proj_out = model.proj_out
+        self.decoder = self.convert_to_rbln_conditional_generation(model)
+        self.output_attentions = output_attentions

+    def convert_to_rbln_conditional_generation(self, model: nn.Module):
+        new_layers = []
+        for layer in model.get_decoder().layers:
+            self_attn = WhisperSelfAttention(layer.self_attn)
+            cross_attn = WhisperCrossAttention(layer.encoder_attn)
+            new_layers.append(WhisperDecoderLayer(layer, self_attn, cross_attn))

-ATTN_FORWARD_MAP = {"eager": _WhisperAttention.forward, "sdpa": _WhisperSdpaAttention.forward}
+        decoder_model = WhisperDecoder(model.get_decoder(), new_layers)

+        return decoder_model

-class _WhisperDecoderLayer(WhisperDecoderLayer):
     def forward(
         self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        attn_impl: str = "eager",
-        output_attentions: bool = False,
-    ) -> torch.Tensor:
-        # Self Attention Block
-        residual = hidden_states
-        hidden_states = self.self_attn_layer_norm(hidden_states)
-        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        decoder_input_ids: torch.Tensor,
+        decoder_attention_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        cross_kv_cache: torch.Tensor,
+        *self_kv_cache: torch.Tensor,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+        # prepare past_key_values
+        self_past_key_values = ()
+        cross_past_key_values = ()
+        for i in range(0, self.num_layers * 2, 2):
+            self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
+            cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)

-        hidden_states, _, present_key_value = ATTN_FORWARD_MAP[attn_impl](
-            self.self_attn,
-            hidden_states=hidden_states,
-            past_key_value=self_attn_past_key_value,
-            attention_mask=attention_mask,
+        # Decode
+        sequence_output, self_present_key_values, cross_attentions = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
             cache_position=cache_position,
+            self_past_key_values=self_past_key_values,
+            cross_past_key_values=cross_past_key_values,
         )
-        hidden_states = residual + hidden_states

-        # Cross-Attention Block
-        residual = hidden_states
-        hidden_states = self.encoder_attn_layer_norm(hidden_states)
-        cross_attn_past_key_value = past_key_value[2:] if past_key_value is not None else None
-        if output_attentions:
-            hidden_states, cross_attn_weights, cross_attn_present_key_value = _WhisperAttention.forward(
-                self.encoder_attn,
-                hidden_states=hidden_states,
-                key_value_states=encoder_hidden_states,
-                past_key_value=cross_attn_past_key_value,
-                cache_position=cache_position,
-            )
-        else:
-            hidden_states, cross_attn_weights, cross_attn_present_key_value = ATTN_FORWARD_MAP[attn_impl](
-                self.encoder_attn,
-                hidden_states=hidden_states,
-                key_value_states=encoder_hidden_states,
-                past_key_value=cross_attn_past_key_value,
-                cache_position=cache_position,
-            )
-        hidden_states = residual + hidden_states
-        present_key_value = present_key_value + cross_attn_present_key_value
+        lm_logits = self.proj_out(sequence_output)

-        # Fully Connected Block
-        residual = hidden_states
-        hidden_states = self.final_layer_norm(hidden_states)
-        hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = residual + hidden_states
+        outputs = (lm_logits,)
+        outputs += self_present_key_values

-        return hidden_states, present_key_value, cross_attn_weights
+        if self.output_attentions:
+            # deocder's cross attention is used for token_timestamps
+            cross_attention = torch.stack(cross_attentions, dim=0)
+            outputs += (cross_attention,)

+        return outputs

-class _WhisperPositionalEmbedding(WhisperPositionalEmbedding):
-    def forward(self, input_ids, past_key_values_length=0, position_ids=None):
-        if position_ids is None:
-            return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
-        else:
-            return self.weight[position_ids]

+class WhisperDecoder(nn.Module):
+    def __init__(self, model, layers, **kwargs):
+        super().__init__()
+        self._original_mod = model
+        self.layers = nn.ModuleList(layers)
+        self.embed_tokens = model.embed_tokens
+        self.layer_norm = model.layer_norm
+        self.embed_positions = model.embed_positions

-class _WhisperDecoder(WhisperDecoder):
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        past_key_values: Optional[torch.Tensor] = None,
+        self_past_key_values: Optional[torch.Tensor] = None,
+        cross_past_key_values: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
-        attn_impl: str = "eager",
-        output_attentions: bool = False,
-        **kwargs,
     ):
         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_shape[-1])

         # positional embeding
         inputs_embeds = self.embed_tokens(input_ids)
-        positions = _WhisperPositionalEmbedding.forward(
-            self.embed_positions, input_ids, cache_position, cache_position
-        )
+        positions = self.embed_positions(input_ids, position_ids=cache_position)
         hidden_states = inputs_embeds + positions

         # prepare casual_attn_mask
-        if self._use_sdpa:
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask, input_shape, inputs_embeds, cache_position
-            )
-        else:
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, input_shape, inputs_embeds, cache_position
-            )
+        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)

-        next_decoder_cache = ()
-        all_cross_attentions = () if output_attentions else None
+        self_present_key_values = ()
+        cross_attentions = ()
         # iterate decoder_layer
-        for idx, decoder_layer in enumerate(self.layers):
-            past_key_value = past_key_values[idx] if past_key_values is not None else None
-            layer_outputs = _WhisperDecoderLayer.forward(
-                decoder_layer,
+        for self_past_key_value, cross_past_key_value, decoder_layer in zip(
+            self_past_key_values, cross_past_key_values, self.layers
+        ):
+            layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
-                encoder_hidden_states=encoder_hidden_states,
-                past_key_value=past_key_value,
+                self_past_key_value=self_past_key_value,
+                cross_past_key_value=cross_past_key_value,
                 cache_position=cache_position,
-                attn_impl=attn_impl,
-                output_attentions=output_attentions,
             )
             hidden_states = layer_outputs[0]
+            self_present_key_values += layer_outputs[1]
+            cross_attentions += (layer_outputs[2],)

-            next_decoder_cache += (layer_outputs[1],)
-            if output_attentions:
-                all_cross_attentions += (layer_outputs[2],)
-
-        # layer_norm
         hidden_states = self.layer_norm(hidden_states)

-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
-            cross_attentions=all_cross_attentions,
-        )
+        return hidden_states, self_present_key_values, cross_attentions


-class _WhisperDecoderWrapper(torch.nn.Module):
-    def __init__(self, model, output_attentions: bool = False):
+class WhisperDecoderLayer(nn.Module):
+    def __init__(self, decoder_layer, self_attn, cross_attn):
         super().__init__()
-        self.proj_out = model.proj_out
-        self.config = model.config
-        self.decoder = model.get_decoder()
-        self.num_layers = self.config.decoder_layers
-        self.attn_impl = self.config._attn_implementation
-        self.output_attentions = output_attentions
+        self._original_mod = decoder_layer
+        self.self_attn = self_attn
+        self.encoder_attn = cross_attn
+        self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
+        self.encoder_attn_layer_norm = decoder_layer.encoder_attn_layer_norm
+        self.final_layer_norm = decoder_layer.final_layer_norm
+        self.activation_fn = decoder_layer.activation_fn
+        self.fc1 = decoder_layer.fc1
+        self.fc2 = decoder_layer.fc2

     def forward(
         self,
-        decoder_input_ids: torch.Tensor,
-        decoder_attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        self_kv_cache: torch.Tensor,
-        cross_kv_cache: torch.Tensor,
-    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
-        # prepare past_key_values
-        kv_cache = ()
-        for i in range(0, self.num_layers * 2, 2):
-            kv_cache = kv_cache + (
-                (
-                    self_kv_cache[i],
-                    self_kv_cache[i + 1],
-                    cross_kv_cache[i],
-                    cross_kv_cache[i + 1],
-                ),
-            )
-
-        # Decode
-        decoder_outputs = _WhisperDecoder.forward(
-            self.decoder,
-            input_ids=decoder_input_ids,
-            attention_mask=decoder_attention_mask,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        self_past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        cross_past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # Self Attention Block
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states, _, self_present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_past_key_value,
+            attention_mask=attention_mask,
             cache_position=cache_position,
-            past_key_values=kv_cache,
-            encoder_hidden_states=torch.tensor([1]),
-            attn_impl=self.attn_impl,
-            output_attentions=self.output_attentions,
         )
-        sequence_output = decoder_outputs[0]
-        lm_logits = self.proj_out(sequence_output)
+        hidden_states = residual + hidden_states

-        # get self_kv_cache from ouputs
-        past_key_values = decoder_outputs[1]
-        self_kv_cache = []
-        for i in range(self.config.decoder_layers):
-            self_kv_cache.append(past_key_values[i][0])
-            self_kv_cache.append(past_key_values[i][1])
-        self_kv_cache = torch.stack(self_kv_cache, dim=0)
+        # Cross-Attention Block
+        residual = hidden_states
+        hidden_states = self.encoder_attn_layer_norm(hidden_states)
+        hidden_states, cross_attn_weights, cross_present_key_value = self.encoder_attn(
+            hidden_states=hidden_states,
+            past_key_value=cross_past_key_value,
+        )
+        hidden_states = residual + hidden_states

-        if self.output_attentions:
-            # deocder's cross attention is used for token_timestamps
-            cross_attention = torch.stack(decoder_outputs[2], dim=0)
-            return lm_logits, self_kv_cache, cross_attention
-        else:
-            return lm_logits, self_kv_cache
+        # Fully Connected Block
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = residual + hidden_states

+        return hidden_states, self_present_key_value, cross_attn_weights

-class _WhisperEncoderWrapper(torch.nn.Module):
-    def __init__(self, model):
+
+class WhisperAttention(nn.Module):
+    def __init__(self, attn):
         super().__init__()
-        self.model = model
-        self.config = model.config
-        self.decoder = model.get_decoder()
-        self.encoder = model.get_encoder()
-        self.num_layers = self.config.decoder_layers
-        self.decoder_max_length = self.config.max_target_positions
-        self.encoder_max_length = self.config.max_source_positions
-        self.num_heads = self.config.decoder_attention_heads
-        self.d_kv = self.config.d_model // self.num_heads
-        self.attn_impl = self.config._attn_implementation
+        self._original_mod = attn
+        self.q_proj = attn.q_proj
+        self.k_proj = attn.k_proj
+        self.v_proj = attn.v_proj
+        self.out_proj = attn.out_proj
+        self.num_heads = attn.num_heads
+        self.embed_dim = attn.embed_dim
+        self.head_dim = attn.embed_dim // attn.num_heads
+        self.scaling = self.head_dim**-0.5
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+
+class WhisperSelfAttention(WhisperAttention):
+    def rbln_cache_update(
+        self,
+        past_key_value: torch.Tensor,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        cache_position: torch.Tensor,
+    ):
+        s_idx = torch.tensor(cache_position, dtype=torch.int16)
+        axis = torch.tensor(2, dtype=torch.int16)
+
+        key_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[0], key_states, s_idx, axis)
+        value_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[1], value_states, s_idx, axis)
+        return key_states, value_states

     def forward(
         self,
-        input_features: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
-        encoder_outputs = self.encoder(input_features=input_features)
+        hidden_states: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, tgt_len, _ = hidden_states.size()
+        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
+        query_states = query_states * self.scaling

-        last_hidden_states = encoder_outputs[0]
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+        key_states, value_states = self.rbln_cache_update(past_key_value, key_states, value_states, cache_position)

-        encoder_batch_size = input_features.shape[0]
-        decoder_batch_size = encoder_batch_size # TODO fix in future
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+        attn_weights = attn_weights + attention_mask
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

-        dummy_past_key_value = []
-        for _ in range(self.num_layers):
-            pkv_self_attn_key = torch.zeros(decoder_batch_size, self.num_heads, self.decoder_max_length, self.d_kv)
-            pkv_self_attn_value = torch.zeros(decoder_batch_size, self.num_heads, self.decoder_max_length, self.d_kv)
-            pkv_cross_attn_key = torch.zeros(encoder_batch_size, self.num_heads, self.encoder_max_length, self.d_kv)
-            pkv_cross_attn_value = torch.zeros(encoder_batch_size, self.num_heads, self.encoder_max_length, self.d_kv)
-            layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
-            dummy_past_key_value.append(layer_pkv)
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)

-        decoder_attention_mask = torch.zeros(decoder_batch_size, self.decoder_max_length, dtype=torch.int64)
-        decoder_attention_mask[:, :1] = 1
+        return attn_output, attn_weights, (key_states, value_states)

-        decoder_outputs = _WhisperDecoder.forward(
-            self.decoder,
-            input_ids=torch.zeros((decoder_batch_size, 1), dtype=torch.int64),
-            attention_mask=decoder_attention_mask,
-            cache_position=torch.tensor(0, dtype=torch.int32),
-            encoder_hidden_states=last_hidden_states,
-            past_key_values=dummy_past_key_value,
-            attn_impl=self.attn_impl,
-            output_attentions=False,
-        )

-        first_past_kv = decoder_outputs[1]
+class WhisperCrossAttention(WhisperSelfAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        batch_size, query_len, _ = hidden_states.size()
+        query_states = self._shape(self.q_proj(hidden_states), query_len, batch_size)
+        query_states = query_states * self.scaling

-        cross_kv = []
-        for layer_out in first_past_kv: # for layer
-            cross_kv.append(layer_out[2])
-            cross_kv.append(layer_out[3])
-        cross_kv = torch.stack(cross_kv, dim=0)
+        key_states = past_key_value[0]
+        value_states = past_key_value[1]
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.view(batch_size, self.num_heads, query_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)

-        return cross_kv
+        return attn_output, attn_weights, (key_states, value_states)
@@ -21,7 +21,6 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-
 import functools
 import glob
 import os
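
The Whisper hunk above replaces the old dummy-decoder pass with an explicit cross-KV precompute: WhisperEncoderWrapper.forward projects the encoder output through each decoder layer's encoder_attn.k_proj/v_proj, stacks the results, and writes them into a pre-allocated cross_key_values buffer via the rbln_cache_update custom op registered by register_rbln_custom_cache_update(). The sketch below mirrors that data flow in plain PyTorch so the tensor shapes are easy to follow; the buffer layout (2 * decoder_layers entries stacked on dim 0) and the in-place copy_ standing in for the custom op are illustrative assumptions, not the compiled runtime's behavior.

import torch
from transformers import WhisperForConditionalGeneration


def precompute_cross_kv(model: WhisperForConditionalGeneration, input_features: torch.Tensor) -> torch.Tensor:
    """Plain-PyTorch sketch of the cross-KV precompute done by WhisperEncoderWrapper."""
    cfg = model.config
    num_heads = cfg.decoder_attention_heads
    d_kv = cfg.d_model // num_heads

    # 1. encoder last_hidden_states
    enc_out = model.get_encoder()(input_features=input_features)[0]
    batch, enc_len = enc_out.shape[0], enc_out.shape[1]

    # 2. project through each decoder layer's cross-attention k/v projections
    kv = []
    for layer in model.get_decoder().layers:
        k = layer.encoder_attn.k_proj(enc_out).view(batch, -1, num_heads, d_kv).transpose(1, 2)
        v = layer.encoder_attn.v_proj(enc_out).view(batch, -1, num_heads, d_kv).transpose(1, 2)
        kv += [k, v]

    # 3. write into a pre-allocated cache; copy_ stands in for
    #    torch.ops.rbln_custom_ops.rbln_cache_update, which updates device DRAM in place.
    cross_key_values = torch.zeros(2 * cfg.decoder_layers, batch, num_heads, enc_len, d_kv)
    cross_key_values.copy_(torch.stack(kv, dim=0))
    return cross_key_values

With a checkpoint such as openai/whisper-tiny, input_features is the log-mel tensor produced by WhisperFeatureExtractor, and the returned buffer follows the layout WhisperDecoderWrapper.forward expects for cross_kv_cache: two entries (key, value) per decoder layer along the first dimension.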