optimum-rbln 0.1.15-py3-none-any.whl → 0.2.1a0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. optimum/rbln/__init__.py +26 -33
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +4 -0
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
  5. optimum/rbln/diffusers/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
  8. optimum/rbln/diffusers/models/controlnet.py +1 -1
  9. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
  10. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
  11. optimum/rbln/diffusers/pipelines/__init__.py +1 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
  13. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
  17. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
  21. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
  27. optimum/rbln/modeling.py +13 -347
  28. optimum/rbln/modeling_base.py +24 -4
  29. optimum/rbln/modeling_config.py +31 -7
  30. optimum/rbln/ops/__init__.py +26 -0
  31. optimum/rbln/ops/attn.py +221 -0
  32. optimum/rbln/ops/flash_attn.py +70 -0
  33. optimum/rbln/ops/kv_cache_update.py +69 -0
  34. optimum/rbln/transformers/__init__.py +20 -0
  35. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  36. optimum/rbln/transformers/modeling_generic.py +385 -0
  37. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
  39. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  40. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
  42. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  43. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
  44. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
  45. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
  46. optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
  47. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  48. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
  49. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  51. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
  52. optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
  53. optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
  54. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  55. optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
  56. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  57. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
  58. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  59. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  60. optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
  61. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  62. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  63. optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
  64. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  65. optimum/rbln/transformers/utils/rbln_quantization.py +1 -2
  66. optimum/rbln/utils/decorator_utils.py +51 -15
  67. optimum/rbln/utils/import_utils.py +8 -1
  68. optimum/rbln/utils/logging.py +38 -1
  69. optimum/rbln/utils/model_utils.py +0 -1
  70. optimum/rbln/utils/runtime_utils.py +9 -3
  71. optimum/rbln/utils/save_utils.py +17 -0
  72. optimum/rbln/utils/submodule.py +23 -0
  73. optimum_rbln-0.2.1a0.dist-info/METADATA +121 -0
  74. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/RECORD +76 -72
  75. optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +288 -0
  76. optimum/rbln/transformers/cache_utils.py +0 -107
  77. optimum/rbln/utils/timer_utils.py +0 -43
  78. optimum_rbln-0.1.15.dist-info/METADATA +0 -106
  79. optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
  80. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/WHEEL +0 -0
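
Two structural changes dominate this list: optimum/rbln/modeling_diffusers.py moves into the diffusers/ package (entry 4), and a new shared seq2seq architecture (entry 58) together with the custom attention ops under optimum/rbln/ops/ (entries 30-33) absorbs most of the per-model code removed from the BART, T5 and Whisper architecture files. As a hedged illustration of the first move (the module path follows entry 4; this sketch is not taken from the package itself), an internal import changes like this:

# Before 0.2.1a0 the mixin lived at the package root:
#   from optimum.rbln.modeling_diffusers import RBLNDiffusionMixin
# After the move in entry 4 the same class resolves under the diffusers subpackage,
# which is the relative-import change visible in the modeling_clip.py hunk below:
from optimum.rbln.diffusers.modeling_diffusers import RBLNDiffusionMixin

The hunks reproduced below cover the BART architecture rewrite, the matching modeling_bart.py updates, and two of the smaller import changes.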
optimum/rbln/transformers/models/bart/bart_architecture.py
@@ -21,497 +21,140 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from typing import Optional, Tuple
+from typing import Tuple

 import torch
 from torch import nn
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_attention_mask,
-    _prepare_4d_attention_mask_for_sdpa,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPastAndCrossAttentions,
-)
-from transformers.models.bart.modeling_bart import (
-    BartAttention,
-    BartDecoder,
-    BartDecoderLayer,
-    BartForConditionalGeneration,
-    BartSdpaAttention,
 )
 from transformers.utils import logging

+from ..seq2seq.seq2seq_architecture import (
+    Seq2SeqDecoder,
+    Seq2SeqDecoderLayer,
+    Seq2SeqDecoderWrapper,
+    Seq2SeqEncoderWrapper,
+    Seq2SeqForConditionalGeneration,
+    Seq2SeqSelfAttention,
+)
+

 logger = logging.get_logger(__name__)


 class BartWrapper:
-    def __init__(self, model):
-        self.encoder = BartEncoderWrapper(model)
+    def __init__(self, model: nn.Module, enc_max_seq_len: int):
+        self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
         self.decoder = BartDecoderWrapper(model)


-class _BartAttention(BartAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        past_key_value: Tuple[torch.Tensor],
-        attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_index: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-        bsz, tgt_len, _ = hidden_states.size()
-        is_cross_attention = key_value_states is not None
+class BartDecoderWrapper(Seq2SeqDecoderWrapper):
+    def convert_to_rbln_conditional_generation(self, model: nn.Module):
+        new_layers = []
+        for layer in model.get_decoder().layers:
+            self_attn = BartSelfAttention(layer.self_attn)
+            new_layers.append(BartDecoderLayer(layer, self_attn))

-        query_states = self.q_proj(hidden_states) * self.scaling
+        decoder_model = BartDecoder(model.get_decoder(), new_layers)
+        new_model = BartForConditionalGeneration(model, decoder_model)

-        if is_cross_attention:
-            is_dummy_decoder = len(key_value_states.shape) > 1
-            if is_dummy_decoder:
-                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-            else:
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-        else:
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        if cache_position.dim() > 0:
-            proj_shape = (bsz, self.num_heads, -1, self.head_dim)
-            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-            key_states = key_states.reshape(*proj_shape)
-            value_states = value_states.reshape(*proj_shape)
-
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-            for b in range(bsz):
-                batch_query_states = query_states[b].unsqueeze(0).unsqueeze(2)
-                batch_attention_mask = attention_mask[b].unsqueeze(0).unsqueeze(2)
-                batch_key_states = key_states[b].unsqueeze(0).unsqueeze(2)
-                batch_value_states = value_states[b].unsqueeze(0).unsqueeze(2)
-                if not is_cross_attention:
-                    batch_key_states = (
-                        past_key_value[0][b]
-                        .unsqueeze(0)
-                        .unsqueeze(2)
-                        .slice_scatter(
-                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-                    batch_value_states = (
-                        past_key_value[1][b]
-                        .unsqueeze(0)
-                        .unsqueeze(2)
-                        .slice_scatter(
-                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-                attn_weights = torch.matmul(batch_query_states, batch_key_states.transpose(3, 4))
-                attn_weights = attn_weights + batch_attention_mask
-                attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-                attn_output = torch.matmul(attn_weights, batch_value_states)
-                attn_output = attn_output.view(1, self.num_heads, tgt_len, self.head_dim)
-                attn_output = attn_output.transpose(1, 2)
-                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
-                all_key_states.append(batch_key_states)
-                all_value_states.append(batch_value_states)
-                all_attn_output.append(attn_output)
-            key_states = torch.cat(all_key_states, dim=0).squeeze(2)
-            value_states = torch.cat(all_value_states, dim=0).squeeze(2)
-            attn_output = torch.cat(all_attn_output, dim=0)
+        return new_model

-        else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            if not is_cross_attention:
-                key_states = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                value_states = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-
-            proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-            key_states = key_states.reshape(*proj_shape)
-            value_states = value_states.reshape(*proj_shape)
-
-            src_len = key_states.size(1)
-            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-            attn_output = torch.bmm(attn_weights, value_states)
-            attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            attn_output = attn_output.transpose(1, 2)
-            key_states = key_states.unsqueeze(0)
-            value_states = value_states.unsqueeze(0)
-            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        attn_output = self.out_proj(attn_output)
-
-        present_key_value = (key_states, value_states)
-
-        return attn_output, present_key_value
-
-
-class _BartSdpaAttention(BartSdpaAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        past_key_value: Tuple[torch.Tensor],
-        attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_index: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-        bsz, tgt_len, _ = hidden_states.size()
-        is_cross_attention = key_value_states is not None
-
-        query_states = self.q_proj(hidden_states)
-
-        if is_cross_attention:
-            is_dummy_decoder = len(key_value_states.shape) > 1
-            if is_dummy_decoder:
-                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-            else:
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-        else:
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        query_states = self._shape(query_states, tgt_len, bsz)
-
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-
-            for b in range(bsz):
-                batch_query_states = query_states[b].unsqueeze(0)
-                batch_attention_mask = attention_mask[b].unsqueeze(0)
-                batch_key_states = key_states[b].unsqueeze(0)
-                batch_value_states = value_states[b].unsqueeze(0)
-
-                if not is_cross_attention:
-                    batch_key_states = (
-                        past_key_value[0][b]
-                        .unsqueeze(0)
-                        .slice_scatter(
-                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-                    batch_value_states = (
-                        past_key_value[1][b]
-                        .unsqueeze(0)
-                        .slice_scatter(
-                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-
-                attn_output = torch.nn.functional.scaled_dot_product_attention(
-                    batch_query_states, batch_key_states, batch_value_states, attn_mask=batch_attention_mask
-                )
-                attn_output = attn_output.transpose(1, 2)
-                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
-                all_key_states.append(batch_key_states)
-                all_value_states.append(batch_value_states)
-                all_attn_output.append(attn_output)
-
-            key_states = torch.cat(all_key_states, dim=0)
-            value_states = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)

+class BartForConditionalGeneration(Seq2SeqForConditionalGeneration):
+    has_rescaling = False
+
+    def __post_init__(self):
+        self.scaling = self.config.d_model**-0.5
+
+
+class BartDecoder(Seq2SeqDecoder):
+    has_pos_emb = True
+
+    def __post_init__(self):
+        self.embed_positions = self._original_mod.embed_positions
+        self.layernorm_embedding = self._original_mod.layernorm_embedding
+        self.embed_scale = getattr(self._original_mod, "embed_scale", None)
+
+    def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
+        attention_mask = attention_mask[:, None, None, :]
+        encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)
+
+        return attention_mask, encoder_attention_mask
+
+    def apply_position_embedding(self, inputs_embeds, cache_position):
+        hidden_all = []
+        for i in range(inputs_embeds.shape[0]):
+            positions_idx = cache_position[i]
+            position_weight = self.embed_positions.weight[2:]
+            position = position_weight[positions_idx]
+            batch_hidden = position + inputs_embeds[i]
+            hidden_all.append(batch_hidden)
+        hidden_states = torch.stack(hidden_all, dim=0)
+
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        return hidden_states
+
+    def get_embedding(self):
+        if self.embed_scale is not None:
+            return lambda x: self.embed_tokens(x) * self.embed_scale
         else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            if not is_cross_attention:
-                key_states = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                value_states = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-
-            # need 4d shape (input tensors) for scaled_dot_product_attention
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states,
-                key_states,
-                value_states,
-                attn_mask=attention_mask,
-            )
-            attn_output = attn_output.transpose(1, 2)
-            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        attn_output = self.out_proj(attn_output)
-
-        present_key_value = (key_states, value_states)
-
-        return attn_output, present_key_value
-
-
-ATTN_FORWARD_MAP = {"eager": _BartAttention.forward, "sdpa": _BartSdpaAttention.forward}
-
-
-class _BartDecoderLayer(BartDecoderLayer):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        past_key_value: Tuple[torch.Tensor],
-        cache_position: torch.Tensor,
-        batch_ids: torch.Tensor,
-        attn_impl: str = "eager",
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-        # Self Attention Block
-        residual = hidden_states
-        self_attn_past_key_value = past_key_value[:2]
-
-        hidden_states, present_key_value = ATTN_FORWARD_MAP[attn_impl](
-            self.self_attn,
-            hidden_states=hidden_states,
-            past_key_value=self_attn_past_key_value,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-            batch_index=batch_ids,
-        )
-        hidden_states = residual + hidden_states
-        hidden_states = self.self_attn_layer_norm(hidden_states)
+            return self.embed_tokens

-        # Cross-Attention Block
-        residual = hidden_states
-        cross_attn_past_key_value = past_key_value[-2:]
-
-        hidden_states, cross_attn_present_key_value = ATTN_FORWARD_MAP[attn_impl](
-            self.encoder_attn,
-            hidden_states=hidden_states,
-            key_value_states=encoder_hidden_states,
-            past_key_value=cross_attn_past_key_value,
-            attention_mask=encoder_attention_mask,
-            cache_position=cache_position,
-            batch_index=batch_ids,
-        )
-        hidden_states = residual + hidden_states
-        hidden_states = self.encoder_attn_layer_norm(hidden_states)
-        present_key_value = present_key_value + cross_attn_present_key_value

-        # Fully Connected Block
+class BartLayerFF(nn.Module):
+    def __init__(self, decoder_layer):
+        super().__init__()
+        self.fc1 = decoder_layer.fc1
+        self.fc2 = decoder_layer.fc2
+        self.activation_fn = decoder_layer.activation_fn
+        self.layer_norm = decoder_layer.final_layer_norm
+
+    def forward(self, hidden_states):
+        # Residual Connection
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
         hidden_states = self.fc2(hidden_states)
         hidden_states = residual + hidden_states
-        hidden_states = self.final_layer_norm(hidden_states)
-
-        return hidden_states, present_key_value
-
-
-class _BartDecoder(BartDecoder):
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        past_key_values: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_ids: torch.Tensor,
-        attn_impl: str = "eager",
-    ):
-        # embedding
-        if hasattr(self, "embed_scale"):
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-        else:
-            inputs_embeds = self.embed_tokens(input_ids)
+        hidden_states = self.layer_norm(hidden_states)
+        return hidden_states

-        if cache_position.dim() == 0:
-            positions_idx = cache_position + self.embed_positions.offset
-            positions = self.embed_positions.weight[positions_idx]
-            hidden_states = inputs_embeds + positions
-        else:
-            hidden_all = []
-            # compiler pattern base dependency -> take + add
-            for i in range(input_ids.shape[0]):
-                # cache position [N,1]
-                positions_idx = cache_position[i]
-                # offset is set 2 in bart embedding
-                position_weight = self.embed_positions.weight[2:]
-                position = position_weight[positions_idx]
-                batch_hidden = position + inputs_embeds[i]
-                hidden_all.append(batch_hidden)
-            hidden_states = torch.stack(hidden_all, dim=0)

-        hidden_states = self.layernorm_embedding(hidden_states)
+class BartDecoderLayer(Seq2SeqDecoderLayer):
+    def __post_init__(self):
+        self.self_attn_layer_norm = self._original_mod.self_attn_layer_norm
+        self.encoder_attn = self._original_mod.encoder_attn
+        self.encoder_attn_layer_norm = self._original_mod.encoder_attn_layer_norm
+        self.ff_layer = BartLayerFF(self._original_mod)

-        # prepare attn_mask
-        input_shape = input_ids.size()
-        if self._use_sdpa:
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask, input_shape, inputs_embeds, cache_position
-            )
-            encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                encoder_attention_mask, torch.float32, tgt_len=input_shape[-1]
-            )
-        else:
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, input_shape, inputs_embeds, cache_position
-            )
-            encoder_attention_mask = _prepare_4d_attention_mask(
-                encoder_attention_mask, torch.float32, tgt_len=input_shape[-1]
-            )
-
-        # iterate decoder_layer
-        next_decoder_cache = ()
-        for idx, decoder_layer in enumerate(self.layers):
-            past_key_value = past_key_values[idx]
-            layer_outputs = _BartDecoderLayer.forward(
-                decoder_layer,
-                hidden_states,
-                attention_mask=attention_mask,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                past_key_value=past_key_value,
-                cache_position=cache_position,
-                batch_ids=batch_ids,
-                attn_impl=attn_impl,
-            )
-            hidden_states = layer_outputs[0]
-            next_decoder_cache += (layer_outputs[1],)
-
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
-        )
-
-
-class BartDecoderWrapper(torch.nn.Module):
-    def __init__(self, model: "BartForConditionalGeneration"):
-        super().__init__()
-        self.config = model.config
-        self.decoder = model.get_decoder()
-        self.num_layers = self.config.decoder_layers
-        self.lm_head = model.lm_head
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_position: torch.Tensor,
-        self_kv_cache: torch.Tensor,
-        cross_kv_cache: torch.Tensor,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
-        if input_ids.shape[1] == 1:
-            rbln_batch_position = None
-        else:
-            rbln_batch_position = batch_position
-        # prepare past_key_values
-        kv_cache = ()
-        for i in range(0, self.num_layers * 2, 2):
-            kv_cache = kv_cache + (
-                (
-                    self_kv_cache[i],
-                    self_kv_cache[i + 1],
-                    cross_kv_cache[i],
-                    cross_kv_cache[i + 1],
-                ),
-            )
-        # decode
-        decoder_outputs = _BartDecoder.forward(
-            self.decoder,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_attention_mask=encoder_attention_mask,
-            cache_position=cache_position,
-            past_key_values=kv_cache,
-            encoder_hidden_states=torch.tensor([1]),
-            attn_impl=self.config._attn_implementation,
-            batch_ids=rbln_batch_position,
-        )
-        sequence_output = decoder_outputs[0]
-        lm_logits = self.lm_head(sequence_output)
-
-        # get self_kv_cache from ouputs
-        past_key_values = decoder_outputs[1]
-        self_kv_cache = []
-        for i in range(self.num_layers):
-            self_kv_cache.append(past_key_values[i][0])
-            self_kv_cache.append(past_key_values[i][1])
-        self_kv_cache = torch.stack(self_kv_cache, dim=0)
-
-        # return batch_position to keep it as a variable within the graph
-        return lm_logits, self_kv_cache, batch_position
-
-
-class BartEncoderWrapper(torch.nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.model = model
-        self.config = model.config
-        self.decoder = model.get_decoder()
-        self.encoder = model.get_encoder()
-        self.num_layers = self.config.encoder_layers
-        self.decoder_max_length = self.config.max_position_embeddings
-        self.encoder_max_length = self.config.max_position_embeddings
-        self.num_heads = self.config.decoder_attention_heads
-        self.d_kv = self.config.d_model // self.num_heads
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.LongTensor,
-        cross_key_value: torch.Tensor = None,
-        batch_idx: torch.Tensor = None,
-    ) -> Tuple[torch.Tensor]:
-        # 1. run encoder
-        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
-        last_hidden_states = encoder_outputs[0]
-
-        # 2. run dummy decoder to get pre-calculated cross-key_values for generation
-        dummy_past_key_value = []
-        for _ in range(self.num_layers):
-            pkv_self_attn_key = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
-            pkv_self_attn_value = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
-            pkv_cross_attn_key = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
-            pkv_cross_attn_value = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
-            layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
-            dummy_past_key_value.append(layer_pkv)
-
-        decoder_attention_mask = torch.zeros(1, self.decoder_max_length, dtype=torch.float32)
-        decoder_attention_mask[:, :1] = 1
-
-        decoder_outputs = _BartDecoder.forward(
-            self.decoder,
-            input_ids=torch.zeros((1, 1), dtype=torch.int64),
-            attention_mask=decoder_attention_mask,
-            encoder_attention_mask=attention_mask,
-            cache_position=torch.tensor(0, dtype=torch.int32),
-            encoder_hidden_states=last_hidden_states,
-            past_key_values=dummy_past_key_value,
-            batch_ids=torch.tensor(0, dtype=torch.int32),
-            attn_impl=self.config._attn_implementation,
-        )
-        first_past_kv = decoder_outputs[1]
-
-        encoder_kv = []
-        for i in range(self.model.config.decoder_layers):
-            encoder_kv.append(first_past_kv[i][2].unsqueeze(0))
-            encoder_kv.append(first_past_kv[i][3].unsqueeze(0))
-        encoder_kv = torch.cat(encoder_kv, dim=0)
-
-        cross_key_value = cross_key_value.slice_scatter(encoder_kv, dim=1, start=batch_idx, end=batch_idx + 1)
-
-        return cross_key_value
+    def pre_self_attn_layer_norm(self, hidden_states):
+        return hidden_states
+
+    def post_self_attn_layer_norm(self, hidden_states):
+        return self.self_attn_layer_norm(hidden_states)
+
+    def pre_cross_attn_layer_norm(self, hidden_states):
+        return hidden_states
+
+    def post_cross_attn_layer_norm(self, hidden_states):
+        return self.encoder_attn_layer_norm(hidden_states)
+
+
+class BartSelfAttention(Seq2SeqSelfAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.out_proj = self._original_mod.out_proj
+        self.num_heads = self._original_mod.num_heads
+        self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+        self.scaling = self.head_dim**-0.5
+        self.attn_decode = torch.ops.rbln_custom_ops.attn_decode
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states = self.q_proj(hidden_states) * self.scaling
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        return query_states, key_states, value_states
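
To make the new structure concrete, the following sketch (not part of the diff) shows how the rewritten BartWrapper is constructed around a stock Hugging Face checkpoint; it mirrors what RBLNBartForConditionalGeneration.wrap_model_if_needed does in the modeling_bart.py hunk below. It assumes an environment where optimum-rbln 0.2.1a0 and the RBLN custom ops (torch.ops.rbln_custom_ops, registered by the new optimum/rbln/ops modules) are available, and it stops short of actual compilation:

from transformers import BartForConditionalGeneration

# BartWrapper now delegates to the shared Seq2Seq* wrappers instead of the
# removed _BartAttention/_BartDecoder classes.
from optimum.rbln.transformers.models.bart.bart_architecture import BartWrapper

hf_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# enc_max_seq_len is the new required argument; 1024 is the fallback that
# wrap_model_if_needed uses when the RBLN config does not provide one.
wrapper = BartWrapper(hf_model, enc_max_seq_len=1024)

encoder_graph = wrapper.encoder  # Seq2SeqEncoderWrapper(model, enc_max_seq_len)
decoder_graph = wrapper.decoder  # BartDecoderWrapper(model), defined above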
optimum/rbln/transformers/models/bart/modeling_bart.py
@@ -24,7 +24,7 @@
 import inspect
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

-from transformers import BartForConditionalGeneration, PretrainedConfig
+from transformers import BartForConditionalGeneration, PretrainedConfig, PreTrainedModel

 from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
@@ -93,11 +93,12 @@ class RBLNBartModel(RBLNModel):
             for model_input_name in rbln_model_input_names
         ]

-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        enc_compile_config = RBLNCompileConfig(input_info=input_info, compiled_model_name="encoder")
+        dec_compile_config = RBLNCompileConfig(input_info=input_info, compiled_model_name="decoder")

         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
+            compile_cfgs=[enc_compile_config, dec_compile_config],
             rbln_kwargs=rbln_kwargs,
         )

@@ -108,7 +109,10 @@
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        return BartWrapper(model)
+        enc_max_seq_len = (
+            rbln_config.model_cfg["enc_max_seq_len"] if "enc_max_seq_len" in rbln_config.model_cfg else 1024
+        )
+        return BartWrapper(model, enc_max_seq_len=enc_max_seq_len)

     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
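
Here enc_max_seq_len is read from rbln_config.model_cfg, with 1024 as the fallback. A hedged usage sketch follows; it assumes the usual optimum-rbln convention that rbln_-prefixed keyword arguments passed at export time are collected into model_cfg, and the exact kwarg name rbln_enc_max_seq_len is an assumption rather than something shown in this diff:

from optimum.rbln import RBLNBartForConditionalGeneration

model = RBLNBartForConditionalGeneration.from_pretrained(
    "facebook/bart-base",
    export=True,               # compile from the Hugging Face checkpoint
    rbln_enc_max_seq_len=512,  # assumed to land in rbln_config.model_cfg["enc_max_seq_len"]
)
# When the value is absent, wrap_model_if_needed falls back to
# BartWrapper(model, enc_max_seq_len=1024).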
optimum/rbln/transformers/models/clip/modeling_clip.py
@@ -34,9 +34,9 @@ from transformers import (
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.models.clip.modeling_clip import CLIPTextModelOutput

+from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
 from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
-from ....modeling_diffusers import RBLNDiffusionMixin


 logger = logging.getLogger(__name__)
optimum/rbln/transformers/models/decoderonly/__init__.py
@@ -21,11 +21,4 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from .decoderonly_architecture import (
-    DecoderOnlyWrapper,
-    RotaryEmbedding,
-    apply_rotary_pos_emb,
-    rotate_half,
-    slice_and_unsqueeze_cos_sin,
-)
 from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
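
Since the decoderonly package __init__ no longer re-exports the architecture helpers, only RBLNDecoderOnlyModelForCausalLM remains importable from the package itself. A hedged migration sketch for code that relied on the old re-export (it assumes DecoderOnlyWrapper still exists under that name after the decoderonly_architecture.py rewrite, which this diff does not show):

# Previously:
#   from optimum.rbln.transformers.models.decoderonly import DecoderOnlyWrapper
# After this change the full module path is required:
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import DecoderOnlyWrapper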