optimum-rbln 0.1.15__py3-none-any.whl → 0.2.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. optimum/rbln/__init__.py +26 -33
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +4 -0
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
  5. optimum/rbln/diffusers/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
  8. optimum/rbln/diffusers/models/controlnet.py +1 -1
  9. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
  10. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
  11. optimum/rbln/diffusers/pipelines/__init__.py +1 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
  13. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
  17. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
  21. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
  27. optimum/rbln/modeling.py +13 -347
  28. optimum/rbln/modeling_base.py +24 -4
  29. optimum/rbln/modeling_config.py +31 -7
  30. optimum/rbln/ops/__init__.py +26 -0
  31. optimum/rbln/ops/attn.py +221 -0
  32. optimum/rbln/ops/flash_attn.py +70 -0
  33. optimum/rbln/ops/kv_cache_update.py +69 -0
  34. optimum/rbln/transformers/__init__.py +20 -0
  35. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  36. optimum/rbln/transformers/modeling_generic.py +385 -0
  37. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
  39. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  40. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
  42. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  43. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
  44. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
  45. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
  46. optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
  47. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  48. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
  49. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  51. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
  52. optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
  53. optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
  54. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  55. optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
  56. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  57. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
  58. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  59. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  60. optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
  61. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  62. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  63. optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
  64. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  65. optimum/rbln/transformers/utils/rbln_quantization.py +1 -2
  66. optimum/rbln/utils/decorator_utils.py +51 -15
  67. optimum/rbln/utils/import_utils.py +8 -1
  68. optimum/rbln/utils/logging.py +38 -1
  69. optimum/rbln/utils/model_utils.py +0 -1
  70. optimum/rbln/utils/runtime_utils.py +9 -3
  71. optimum/rbln/utils/save_utils.py +17 -0
  72. optimum/rbln/utils/submodule.py +23 -0
  73. optimum_rbln-0.2.1a0.dist-info/METADATA +121 -0
  74. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/RECORD +76 -72
  75. optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +288 -0
  76. optimum/rbln/transformers/cache_utils.py +0 -107
  77. optimum/rbln/utils/timer_utils.py +0 -43
  78. optimum_rbln-0.1.15.dist-info/METADATA +0 -106
  79. optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
  80. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/WHEEL +0 -0
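
The largest hunk below comes from optimum/rbln/transformers/models/t5/t5_architecture.py, which replaces the hand-written T5 encoder/decoder graph code with thin subclasses of the shared Seq2Seq base classes added in this release (optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py). As a rough orientation, here is a minimal sketch of how the refactored wrapper is put together. The class names and the T5Wrapper signature are taken from the hunk itself; the checkpoint name and sequence lengths are illustrative assumptions, and tracing or compiling the wrapped graphs still requires the RBLN SDK.

# Minimal sketch, not an official example: T5Wrapper and its constructor signature
# come from the t5_architecture.py hunk below. The checkpoint name and sequence
# lengths are illustrative assumptions; compiling the graphs needs the RBLN SDK.
from transformers import T5ForConditionalGeneration

from optimum.rbln.transformers.models.t5.t5_architecture import T5Wrapper

hf_model = T5ForConditionalGeneration.from_pretrained("t5-small")

# T5Wrapper splits the model into an encoder-side graph and a decoder-side graph,
# each built on the new Seq2SeqEncoderWrapper / Seq2SeqDecoderWrapper base classes.
wrapper = T5Wrapper(hf_model, enc_max_seq_len=512, dec_max_seq_len=512)
encoder_graph, decoder_graph = wrapper.encoder, wrapper.decoder

The refactored BART architecture in this release appears to follow the same pattern, as suggested by the "# different from bart" comment inside the new encoder wrapper.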
--- a/optimum/rbln/transformers/models/t5/t5_architecture.py
+++ b/optimum/rbln/transformers/models/t5/t5_architecture.py
@@ -21,494 +21,152 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import Tuple
 
 import torch
 from torch import nn
-from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions
-from transformers.models.t5.configuration_t5 import T5Config
-from transformers.models.t5.modeling_t5 import (
-    T5Attention,
-    T5Block,
-    T5LayerCrossAttention,
-    T5LayerSelfAttention,
-    T5Stack,
-)
 from transformers.utils import logging
 
+from ....ops import register_rbln_custom_attention_add_softmax
+from ..seq2seq.seq2seq_architecture import (
+    Seq2SeqDecoder,
+    Seq2SeqDecoderLayer,
+    Seq2SeqDecoderWrapper,
+    Seq2SeqEncoderWrapper,
+    Seq2SeqForConditionalGeneration,
+    Seq2SeqSelfAttention,
+)
 
-logger = logging.get_logger(__name__)
 
-if TYPE_CHECKING:
-    from transformers import T5ForConditionalGeneration
+logger = logging.get_logger(__name__)
 
 
 class T5Wrapper:
-    def __init__(self, model):
-        self.encoder = T5EncoderWrapper(model)
-        self.decoder = T5DecoderWrapper(model)
+    def __init__(self, model: nn.Module, enc_max_seq_len: int, dec_max_seq_len: int = None):
+        self.encoder = T5EncoderWrapper(model, enc_max_seq_len)
+        self.decoder = T5DecoderWrapper(model, dec_max_seq_len=dec_max_seq_len)
+
+
+class T5EncoderWrapper(Seq2SeqEncoderWrapper):
+    def __post_init__(self, model: nn.Module):
+        self.n_layer = getattr(self.config, "num_layers")
+        self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().block)
+        self.num_heads = self.config.num_heads
+        self.d_kv = self.config.d_kv
+
+    def _extract_cross_kv_projects(self, t5_block: nn.Module):
+        return (
+            # different from bart
+            nn.ModuleList(t5_block[i].layer[1].EncDecAttention.k for i in range(self.n_layer)),
+            nn.ModuleList(t5_block[i].layer[1].EncDecAttention.v for i in range(self.n_layer)),
+        )
 
 
-class T5Encoder(T5Stack):
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        position_bias: torch.Tensor,
-        batch_ids: torch.Tensor = None,
-    ) -> BaseModelOutput:
-        hidden_states = self.embed_tokens(input_ids)
-        extended_attention_mask = self.invert_attention_mask(attention_mask)
-        position_bias = position_bias + extended_attention_mask
-        for i, layer_module in enumerate(self.block):
-            layer_outputs = _T5Block.forward(
-                layer_module,
-                hidden_states,
-                position_bias=position_bias,
-                batch_ids=batch_ids,
-            )
-            hidden_states = layer_outputs[0]
-        hidden_states = self.final_layer_norm(hidden_states)
-        return BaseModelOutput(last_hidden_state=hidden_states)
-
-
-class T5Decoder(T5Stack):
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        past_key_values: torch.Tensor,
-        position_bias: torch.Tensor,
-        encoder_decoder_position_bias: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_ids: torch.Tensor,
-    ) -> BaseModelOutputWithPastAndCrossAttentions:
-        hidden_states = self.embed_tokens(input_ids)
-        extended_attention_mask = self.invert_attention_mask(attention_mask)
-        encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
-
-        position_bias = position_bias + extended_attention_mask
-        encoder_decoder_position_bias = encoder_decoder_position_bias + encoder_extended_attention_mask
-
-        present_key_value_states = ()
-
-        for layer_module, past_key_value in zip(self.block, past_key_values):
-            layer_outputs = _T5Block.forward(
-                layer_module,
-                hidden_states,
-                position_bias=position_bias,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_decoder_position_bias=encoder_decoder_position_bias,
-                past_key_value=past_key_value,
-                cache_position=cache_position,
-                batch_ids=batch_ids,
-            )
-            hidden_states, present_key_value_state = layer_outputs[:2]
-            present_key_value_states = present_key_value_states + (present_key_value_state,)
-
-        hidden_states = self.final_layer_norm(hidden_states)
-
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=present_key_value_states,
-        )
+class T5DecoderWrapper(Seq2SeqDecoderWrapper):
+    def __post_init__(self, model, dec_max_seq_len: int = None):
+        register_rbln_custom_attention_add_softmax()
+        self.num_layers = self.config.num_layers
+        self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)
 
+    def convert_to_rbln_conditional_generation(self, model: nn.Module, dec_max_seq_len: int):
+        new_blocks = []
+        for block in model.get_decoder().block:
+            self_attn = T5LayerSelfAttention(block.layer[0].SelfAttention)
+            block = T5Block(block, self_attn)
+            new_blocks.append(block)
 
-class T5EncoderWrapper(torch.nn.Module):
-    def __init__(self, model: "T5ForConditionalGeneration"):
-        super().__init__()
-        self.config = model.config
-        self.model = model
-        self.encoder = model.encoder
-        self.decoder = model.decoder
-        self.default_max_length = getattr(self.config, "n_positions", None) or getattr(
-            self.config, "max_position_embeddings", None
-        )
-        self.encoder_max_length = None
-        self.decoder_max_length = None
+        decoder_model = T5Decoder(model.get_decoder(), new_blocks, dec_max_seq_len=dec_max_seq_len)
+        new_model = T5ForConditionalGeneration(model, decoder_model)
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        cross_key_value: torch.Tensor = None,
-        batch_idx: torch.Tensor = None,
-    ) -> torch.Tensor:
-        decoder_max_length = self.decoder_max_length or self.default_max_length
-        encoder_max_length = self.encoder_max_length or self.default_max_length
-
-        attn_layer = self.encoder.block[0].layer[0].SelfAttention
-        encoder_position_bias = T5Attention.compute_bias(attn_layer, encoder_max_length, encoder_max_length)
-        encoder_outputs = T5Encoder.forward(
-            self.encoder,
-            input_ids,
-            attention_mask,
-            encoder_position_bias,
-            batch_ids=torch.tensor(0, dtype=torch.int32),
-        )
+        return new_model
 
-        attn_layer = self.decoder.block[0].layer[0].SelfAttention
-        decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
-        decoder_position_bias = decoder_position_bias[:, :, :1]
-
-        attn_layer = self.decoder.block[0].layer[1].EncDecAttention
-        encoder_decoder_position_bias = torch.zeros(1, attn_layer.n_heads, 1, encoder_max_length)
-
-        dummy_past_key_value = []
-        for i in range(self.config.num_layers):
-            pkv_self_attn_key = torch.zeros(1, self.config.num_heads, decoder_max_length, self.config.d_kv)
-            pkv_self_attn_value = torch.zeros(1, self.config.num_heads, decoder_max_length, self.config.d_kv)
-            pkv_cross_attn_key = torch.zeros(1, self.config.num_heads, encoder_max_length, self.config.d_kv)
-            pkv_cross_attn_value = torch.zeros(1, self.config.num_heads, encoder_max_length, self.config.d_kv)
-            layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
-            dummy_past_key_value.append(layer_pkv)
-
-        decoder_attention_mask = torch.zeros(1, decoder_max_length, dtype=torch.float32)
-        decoder_attention_mask[:, :1] = 1
-
-        # Since first step of decoder has different graph to further step of it,
-        # here we merges decoder into its corresponding encoder.
-        # TODO(jongho): Separate first-step-decoder.
-        decoder_outputs = T5Decoder.forward(
-            self.decoder,
-            input_ids=torch.zeros(1, 1, dtype=torch.int64),
-            attention_mask=decoder_attention_mask,
-            position_bias=decoder_position_bias,
-            encoder_decoder_position_bias=encoder_decoder_position_bias,
-            encoder_hidden_states=encoder_outputs.last_hidden_state,
-            encoder_attention_mask=attention_mask,
-            past_key_values=dummy_past_key_value,
-            cache_position=torch.tensor(0, dtype=torch.int32),
-            batch_ids=torch.tensor(0, dtype=torch.int32),
-        )
 
-        past_key_values = decoder_outputs.past_key_values
+class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
+    has_rescaling = True
 
-        cross_kv_cache = []
-        for i in range(self.model.config.num_layers):
-            cross_kv_cache.append(past_key_values[i][2])
-            cross_kv_cache.append(past_key_values[i][3])
-        cross_kv_cache = torch.stack(cross_kv_cache, dim=0)
+    def __post_init__(self):
+        self.scaling = self.config.d_model**-0.5
 
-        cross_key_value = cross_key_value.slice_scatter(cross_kv_cache, dim=1, start=batch_idx, end=batch_idx + 1)
 
-        return cross_key_value
+class T5Decoder(Seq2SeqDecoder):
+    has_pos_emb = False
 
+    def __post_init__(self, dec_max_seq_len: int = None):
+        self.invert_attention_mask = self._original_mod.invert_attention_mask
+        self._dec_position_bias = self.precompute_dec_position_bias(self._original_mod, dec_max_seq_len)
 
-class T5DecoderWrapper(torch.nn.Module):
-    def __init__(self, model: "T5ForConditionalGeneration"):
-        super().__init__()
-        self.config = model.config
-        self.model = model
-        self.encoder = model.encoder
-        self.decoder = model.decoder
-        self.default_max_length = getattr(self.config, "n_positions", None) or getattr(
-            self.config, "max_position_embeddings", None
-        )
-        self.encoder_max_length = None
-        self.decoder_max_length = None
+    def precompute_dec_position_bias(self, model, dec_max_length):
+        attn_layer = model.block[0].layer[0].SelfAttention
+        return attn_layer.compute_bias(dec_max_length, dec_max_length)
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_position: torch.Tensor,
-        self_kv_cache: torch.Tensor,
-        cross_kv_cache: torch.Tensor,
-    ) -> Tuple[torch.Tensor]:
-        # cache_position : starts from step 0
-        # attention_mask : starts with one position marked ([0:cache_position+1])
-        num_layers = self.model.config.num_layers
-        encoder_max_length = self.encoder_max_length or self.default_max_length
-        decoder_max_length = self.decoder_max_length or self.default_max_length
-
-        if input_ids.shape[1] == 1:
-            rbln_batch_position = None
-        else:
-            rbln_batch_position = batch_position
-
-        kv_cache = ()
-        for i in range(0, num_layers * 2, 2):
-            kv_cache = kv_cache + (
-                (
-                    self_kv_cache[i],
-                    self_kv_cache[i + 1],
-                    cross_kv_cache[i],
-                    cross_kv_cache[i + 1],
-                ),
-            )
-
-        attn_layer = self.model.decoder.block[0].layer[0].SelfAttention
-        _decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
-
-        # position_bias need to compute with batch (for cb)
+    def prepare_attn_mask(self, attention_mask, encoder_attention_mask, cache_position):
+        attention_mask = self.invert_attention_mask(attention_mask)
+        encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+
+        b_size = attention_mask.shape[0]
         batch_decoder_position_bias = []
-        for i in range(input_ids.shape[0]):
-            batch_position_bias = _decoder_position_bias[:, :, cache_position[i][0]].unsqueeze(2)
+        for i in range(b_size):
+            batch_position_bias = self._dec_position_bias[:, :, cache_position[i][0]].unsqueeze(2)
             batch_decoder_position_bias.append(batch_position_bias)
-        decoder_position_bias = torch.cat(batch_decoder_position_bias, dim=0)
-
-        attn_layer = self.model.decoder.block[0].layer[1].EncDecAttention
-        encoder_decoder_position_bias = torch.zeros(1, attn_layer.n_heads, 1, encoder_max_length)
-
-        decoder_outputs = T5Decoder.forward(
-            self.model.decoder,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=1,
-            encoder_attention_mask=encoder_attention_mask,
-            position_bias=decoder_position_bias,
-            encoder_decoder_position_bias=encoder_decoder_position_bias,
-            past_key_values=kv_cache,
-            cache_position=cache_position,
-            batch_ids=rbln_batch_position,
-        )
+        position_bias = torch.cat(batch_decoder_position_bias, dim=0)
 
-        past_key_values = decoder_outputs.past_key_values
-        sequence_output = decoder_outputs[0]
-        if self.model.config.tie_word_embeddings:
-            # Rescale output before projecting on vocab
-            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
-            sequence_output = sequence_output * (self.model.model_dim**-0.5)
-        lm_logits = self.model.lm_head(sequence_output)
+        attention_mask = position_bias + attention_mask
 
-        self_kv_cache = []
-        for i in range(self.model.config.num_layers):
-            self_kv_cache.append(past_key_values[i][0])
-            self_kv_cache.append(past_key_values[i][1])
+        return attention_mask, encoder_attention_mask
 
-        self_kv_cache = torch.stack(self_kv_cache, dim=0)
 
-        return lm_logits, self_kv_cache, batch_position
+class T5Block(Seq2SeqDecoderLayer):
+    def __post_init__(self):
+        self.self_attn_layer_norm = self._original_mod.layer[0].layer_norm
+        self.encoder_attn_layer_norm = self._original_mod.layer[1].layer_norm
+        self.encoder_attn = T5CrossAttention(self._original_mod.layer[1].EncDecAttention)
+        self.ff_layer = self._original_mod.layer[2]
 
+    def pre_self_attn_layer_norm(self, hidden_states):
+        return self.self_attn_layer_norm(hidden_states)
 
-class _T5Attention(T5Attention):
-    def __init__(self, config: T5Config, has_relative_attention_bias=False):
-        super().__init__(config, has_relative_attention_bias)
+    def post_self_attn_layer_norm(self, hidden_states):
+        return hidden_states
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        key_value_states: Tuple[torch.Tensor] = None,
-        position_bias: torch.Tensor = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        cache_position: Optional[torch.Tensor] = None,  # current cache sequence length
-        batch_index: torch.Tensor = None,
-        is_self_attn: Optional[bool] = None,
-    ) -> Tuple[torch.Tensor]:
-        batch_size = hidden_states.shape[0]
-
-        def shape(states, batch_size):
-            """projection"""
-            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
-
-        def unshape(states, batch_size):
-            """reshape"""
-            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
-
-        query_states = shape(self.q(hidden_states), batch_size)  # (batch_size, n_heads, seq_length, dim_per_head)
-
-        # projection
-        if is_self_attn:
-            key_states = shape(self.k(hidden_states), batch_size)
-            value_states = shape(self.v(hidden_states), batch_size)
-        else:
-            # cross-attn
-            if cache_position.dim() == 0:
-                key_states = shape(self.k(key_value_states), key_value_states.shape[0])
-                value_states = shape(self.v(key_value_states), key_value_states.shape[0])
-                past_key_value = key_states, value_states
-            else:
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-
-        if (batch_index is None or batch_index == -1) and batch_size > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-
-            for b in range(batch_size):
-                batch_query_states = query_states[b].unsqueeze(0)
-                batch_key_states = key_states[b].unsqueeze(0)
-                batch_value_states = value_states[b].unsqueeze(0)
-
-                if is_self_attn and past_key_value is not None:
-                    batch_key_states = (
-                        past_key_value[0][b]
-                        .unsqueeze(0)
-                        .slice_scatter(
-                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-                    batch_value_states = (
-                        past_key_value[1][b]
-                        .unsqueeze(0)
-                        .slice_scatter(
-                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-
-                scores = torch.matmul(batch_query_states, batch_key_states.transpose(3, 2))
-                scores += position_bias[b]
-                attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
-                attn_output = unshape(torch.matmul(attn_weights, batch_value_states), 1)
-                all_key_states.append(batch_key_states)
-                all_value_states.append(batch_value_states)
-                all_attn_output.append(attn_output)
-
-            key_states = torch.cat(all_key_states, dim=0)
-            value_states = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
-
-        else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            if is_self_attn and past_key_value is not None:
-                key_states = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                value_states = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-            # compute scores
-            scores = torch.matmul(query_states, key_states.transpose(3, 2))
-            scores += position_bias
-
-            attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
-                scores
-            )  # (batch_size, n_heads, seq_length, key_length)
-
-            attn_output = unshape(
-                torch.matmul(attn_weights, value_states), batch_size
-            )  # (batch_size, seq_length, dim)
-
-        attn_output = self.o(attn_output)
-        present_key_value = (key_states, value_states)
-        outputs = (attn_output,) + (present_key_value,)
-        return outputs
-
-
-class _T5LayerSelfAttention(T5LayerSelfAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        position_bias: torch.Tensor = None,
-        past_key_value: Tuple[torch.Tensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        batch_index: torch.Tensor = None,
-    ):
-        normed_hidden_states = self.layer_norm(hidden_states)
-        attention_output = _T5Attention.forward(
-            self.SelfAttention,
-            hidden_states=normed_hidden_states,
-            position_bias=position_bias,
-            past_key_value=past_key_value,
-            cache_position=cache_position,
-            batch_index=batch_index,
-            is_self_attn=True,
-        )
+    def pre_cross_attn_layer_norm(self, hidden_states):
+        return self.encoder_attn_layer_norm(hidden_states)
 
-        # Residual Connection
-        hidden_states = hidden_states + self.dropout(attention_output[0])
-        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
-        return outputs
+    def post_cross_attn_layer_norm(self, hidden_states):
+        return hidden_states
 
 
-class _T5LayerCrossAttention(T5LayerCrossAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        key_value_states: torch.Tensor,
-        position_bias: torch.Tensor = None,
-        past_key_value: Tuple[torch.Tensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        batch_index: torch.Tensor = None,
-    ):
-        normed_hidden_states = self.layer_norm(hidden_states)
-        attention_output = _T5Attention.forward(
-            self.EncDecAttention,
-            hidden_states=normed_hidden_states,
-            key_value_states=key_value_states,
-            position_bias=position_bias,
-            past_key_value=past_key_value,
-            cache_position=cache_position,
-            batch_index=batch_index,
-            is_self_attn=False,
-        )
+class T5LayerSelfAttention(Seq2SeqSelfAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q
+        self.k_proj = self._original_mod.k
+        self.v_proj = self._original_mod.v
+        self.out_proj = self._original_mod.o
+        self.num_heads = self._original_mod.n_heads
+        self.head_dim = self._original_mod.key_value_proj_dim
+        self.attn_decode = torch.ops.rbln_custom_ops.attn_decode_add_softmax
 
-        # Residual connection
-        layer_output = hidden_states + self.dropout(attention_output[0])
-        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
-        return outputs
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        return query_states, key_states, value_states
 
 
-class _T5Block(T5Block):
+class T5CrossAttention(nn.Module):
+    def __init__(self, attn):
+        super().__init__()
+        self.attn = attn
+
     def forward(
         self,
-        hidden_states,
-        position_bias=None,
-        encoder_hidden_states=None,
-        encoder_decoder_position_bias=None,
-        past_key_value=None,
-        cache_position=None,
-        batch_ids=None,
+        hidden_states: torch.Tensor = None,
+        past_key_value: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
+        key_value_states: torch.Tensor = None,
     ):
-        if past_key_value is not None:
-            if not self.is_decoder:
-                logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
-            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
-
-            if len(past_key_value) != expected_num_past_key_values:
-                raise ValueError(
-                    f"There should be {expected_num_past_key_values} past states. "
-                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
-                    f"Got {len(past_key_value)} past key / value states"
-                )
-
-            self_attn_past_key_value = past_key_value[:2]
-            if self_attn_past_key_value == (None, None):
-                self_attn_past_key_value = None
-
-            cross_attn_past_key_value = past_key_value[2:]
-        else:
-            self_attn_past_key_value, cross_attn_past_key_value = None, None
-        self_attention_outputs = _T5LayerSelfAttention.forward(
-            self.layer[0],
+        return self.attn(
             hidden_states=hidden_states,
-            position_bias=position_bias,
-            past_key_value=self_attn_past_key_value,
-            cache_position=cache_position,
-            batch_index=batch_ids,
+            past_key_value=past_key_value,
+            position_bias=attention_mask,
+            key_value_states=key_value_states,
         )
-
-        hidden_states, present_key_value_state = self_attention_outputs[:2]
-
-        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
-        if do_cross_attention:
-            cross_attention_outputs = _T5LayerCrossAttention.forward(
-                self.layer[1],
-                hidden_states,
-                key_value_states=encoder_hidden_states,
-                position_bias=encoder_decoder_position_bias,
-                past_key_value=cross_attn_past_key_value,
-                cache_position=cache_position,
-                batch_index=batch_ids,
-            )
-            hidden_states = cross_attention_outputs[0]
-            # Combine self attn and cross attn key value states
-            if present_key_value_state is not None:
-                # print(present_key_value_state.shape)
-                present_key_value_state = present_key_value_state + cross_attention_outputs[1]
-
-        # Apply Feed Forward layer
-        hidden_states = self.layer[-1](hidden_states)
-
-        outputs = (hidden_states,)
-        outputs = outputs + (present_key_value_state,)
-
-        return outputs
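
One dependency of the hunk above worth spelling out: the relative import from ....ops resolves to the new optimum/rbln/ops package, and T5DecoderWrapper registers the custom attention kernel there before T5LayerSelfAttention binds it. A hedged sketch of that wiring in isolation follows; both names appear in the hunk, and actually resolving the op at runtime presumably requires the RBLN toolchain.

# Sketch only: both names are taken from the hunk above. Registering and resolving
# the custom op at runtime presumably requires the RBLN toolchain to be installed.
import torch

from optimum.rbln.ops import register_rbln_custom_attention_add_softmax

register_rbln_custom_attention_add_softmax()
attn_decode = torch.ops.rbln_custom_ops.attn_decode_add_softmax  # as bound in T5LayerSelfAttention.__post_init__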
--- a/optimum/rbln/transformers/models/whisper/generation_whisper.py
+++ b/optimum/rbln/transformers/models/whisper/generation_whisper.py
@@ -1,3 +1,45 @@
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+"""
+Generation utilities for Whisper.
+Modified from `transformers.models.whisper.generation_whisper.py`
+"""
+
 import torch
 from transformers import GenerationMixin
 from transformers.models.whisper.generation_whisper import WhisperGenerationMixin