optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. optimum/rbln/__init__.py +41 -38
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +26 -2
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
  5. optimum/rbln/diffusers/models/__init__.py +36 -3
  6. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  7. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
  8. optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
  9. optimum/rbln/diffusers/models/controlnet.py +54 -14
  10. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  11. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  12. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  13. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
  14. optimum/rbln/diffusers/pipelines/__init__.py +23 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
  19. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
  23. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
  31. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  32. optimum/rbln/modeling.py +238 -0
  33. optimum/rbln/modeling_base.py +186 -760
  34. optimum/rbln/modeling_config.py +31 -7
  35. optimum/rbln/ops/__init__.py +26 -0
  36. optimum/rbln/ops/attn.py +221 -0
  37. optimum/rbln/ops/flash_attn.py +70 -0
  38. optimum/rbln/ops/kv_cache_update.py +69 -0
  39. optimum/rbln/transformers/__init__.py +20 -2
  40. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  41. optimum/rbln/transformers/modeling_generic.py +385 -0
  42. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  43. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  44. optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
  45. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  46. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  47. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
  48. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  49. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  50. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
  52. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
  53. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  54. optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
  55. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  56. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  57. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  58. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  59. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  60. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
  61. optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
  62. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
  63. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  64. optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
  65. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  66. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  68. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  69. optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  71. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  72. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  73. optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
  74. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  75. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  76. optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
  77. optimum/rbln/utils/decorator_utils.py +51 -11
  78. optimum/rbln/utils/hub.py +131 -0
  79. optimum/rbln/utils/import_utils.py +22 -1
  80. optimum/rbln/utils/logging.py +37 -0
  81. optimum/rbln/utils/model_utils.py +52 -0
  82. optimum/rbln/utils/runtime_utils.py +10 -4
  83. optimum/rbln/utils/save_utils.py +17 -0
  84. optimum/rbln/utils/submodule.py +137 -0
  85. optimum_rbln-0.2.0.dist-info/METADATA +117 -0
  86. optimum_rbln-0.2.0.dist-info/RECORD +114 -0
  87. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
  88. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
  89. optimum/rbln/transformers/cache_utils.py +0 -107
  90. optimum/rbln/transformers/generation/streamers.py +0 -139
  91. optimum/rbln/transformers/generation/utils.py +0 -397
  92. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  93. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  94. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  95. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  96. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  97. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  98. optimum/rbln/utils/context.py +0 -58
  99. optimum/rbln/utils/timer_utils.py +0 -43
  100. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  101. optimum_rbln-0.1.13.dist-info/RECORD +0 -107
  102. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  103. optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
optimum/rbln/transformers/models/bart/bart_architecture.py
@@ -21,497 +21,140 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from typing import Optional, Tuple
+from typing import Tuple

 import torch
 from torch import nn
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_attention_mask,
-    _prepare_4d_attention_mask_for_sdpa,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPastAndCrossAttentions,
-)
-from transformers.models.bart.modeling_bart import (
-    BartAttention,
-    BartDecoder,
-    BartDecoderLayer,
-    BartForConditionalGeneration,
-    BartSdpaAttention,
 )
 from transformers.utils import logging

+from ..seq2seq.seq2seq_architecture import (
+    Seq2SeqDecoder,
+    Seq2SeqDecoderLayer,
+    Seq2SeqDecoderWrapper,
+    Seq2SeqEncoderWrapper,
+    Seq2SeqForConditionalGeneration,
+    Seq2SeqSelfAttention,
+)
+

 logger = logging.get_logger(__name__)


 class BartWrapper:
-    def __init__(self, model):
-        self.encoder = BartEncoderWrapper(model)
+    def __init__(self, model: nn.Module, enc_max_seq_len: int):
+        self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
         self.decoder = BartDecoderWrapper(model)


-class _BartAttention(BartAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        past_key_value: Tuple[torch.Tensor],
-        attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_index: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-        bsz, tgt_len, _ = hidden_states.size()
-        is_cross_attention = key_value_states is not None
+class BartDecoderWrapper(Seq2SeqDecoderWrapper):
+    def convert_to_rbln_conditional_generation(self, model: nn.Module):
+        new_layers = []
+        for layer in model.get_decoder().layers:
+            self_attn = BartSelfAttention(layer.self_attn)
+            new_layers.append(BartDecoderLayer(layer, self_attn))

-        query_states = self.q_proj(hidden_states) * self.scaling
+        decoder_model = BartDecoder(model.get_decoder(), new_layers)
+        new_model = BartForConditionalGeneration(model, decoder_model)

-        if is_cross_attention:
-            is_dummy_decoder = len(key_value_states.shape) > 1
-            if is_dummy_decoder:
-                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-            else:
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-        else:
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        if cache_position.dim() > 0:
-            proj_shape = (bsz, self.num_heads, -1, self.head_dim)
-            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-            key_states = key_states.reshape(*proj_shape)
-            value_states = value_states.reshape(*proj_shape)
-
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-            for b in range(bsz):
-                batch_query_states = query_states[b].unsqueeze(0).unsqueeze(2)
-                batch_attention_mask = attention_mask[b].unsqueeze(0).unsqueeze(2)
-                batch_key_states = key_states[b].unsqueeze(0).unsqueeze(2)
-                batch_value_states = value_states[b].unsqueeze(0).unsqueeze(2)
-                if not is_cross_attention:
-                    batch_key_states = (
-                        past_key_value[0][b]
-                        .unsqueeze(0)
-                        .unsqueeze(2)
-                        .slice_scatter(
-                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-                    batch_value_states = (
-                        past_key_value[1][b]
-                        .unsqueeze(0)
-                        .unsqueeze(2)
-                        .slice_scatter(
-                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-                attn_weights = torch.matmul(batch_query_states, batch_key_states.transpose(3, 4))
-                attn_weights = attn_weights + batch_attention_mask
-                attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-                attn_output = torch.matmul(attn_weights, batch_value_states)
-                attn_output = attn_output.view(1, self.num_heads, tgt_len, self.head_dim)
-                attn_output = attn_output.transpose(1, 2)
-                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
-                all_key_states.append(batch_key_states)
-                all_value_states.append(batch_value_states)
-                all_attn_output.append(attn_output)
-            key_states = torch.cat(all_key_states, dim=0).squeeze(2)
-            value_states = torch.cat(all_value_states, dim=0).squeeze(2)
-            attn_output = torch.cat(all_attn_output, dim=0)
+        return new_model

-        else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            if not is_cross_attention:
-                key_states = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                value_states = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-
-            proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-            key_states = key_states.reshape(*proj_shape)
-            value_states = value_states.reshape(*proj_shape)
-
-            src_len = key_states.size(1)
-            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-            attn_output = torch.bmm(attn_weights, value_states)
-            attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            attn_output = attn_output.transpose(1, 2)
-            key_states = key_states.unsqueeze(0)
-            value_states = value_states.unsqueeze(0)
-            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        attn_output = self.out_proj(attn_output)
-
-        present_key_value = (key_states, value_states)
-
-        return attn_output, present_key_value
-
-
-class _BartSdpaAttention(BartSdpaAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        past_key_value: Tuple[torch.Tensor],
-        attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_index: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-        bsz, tgt_len, _ = hidden_states.size()
-        is_cross_attention = key_value_states is not None
-
-        query_states = self.q_proj(hidden_states)
-
-        if is_cross_attention:
-            is_dummy_decoder = len(key_value_states.shape) > 1
-            if is_dummy_decoder:
-                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-            else:
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-        else:
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        query_states = self._shape(query_states, tgt_len, bsz)
-
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-
-            for b in range(bsz):
-                batch_query_states = query_states[b].unsqueeze(0)
-                batch_attention_mask = attention_mask[b].unsqueeze(0)
-                batch_key_states = key_states[b].unsqueeze(0)
-                batch_value_states = value_states[b].unsqueeze(0)
-
-                if not is_cross_attention:
-                    batch_key_states = (
-                        past_key_value[0][b]
-                        .unsqueeze(0)
-                        .slice_scatter(
-                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-                    batch_value_states = (
-                        past_key_value[1][b]
-                        .unsqueeze(0)
-                        .slice_scatter(
-                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-
-                attn_output = torch.nn.functional.scaled_dot_product_attention(
-                    batch_query_states, batch_key_states, batch_value_states, attn_mask=batch_attention_mask
-                )
-                attn_output = attn_output.transpose(1, 2)
-                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
-                all_key_states.append(batch_key_states)
-                all_value_states.append(batch_value_states)
-                all_attn_output.append(attn_output)
-
-            key_states = torch.cat(all_key_states, dim=0)
-            value_states = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)

+class BartForConditionalGeneration(Seq2SeqForConditionalGeneration):
+    has_rescaling = False
+
+    def __post_init__(self):
+        self.scaling = self.config.d_model**-0.5
+
+
+class BartDecoder(Seq2SeqDecoder):
+    has_pos_emb = True
+
+    def __post_init__(self):
+        self.embed_positions = self._original_mod.embed_positions
+        self.layernorm_embedding = self._original_mod.layernorm_embedding
+        self.embed_scale = getattr(self._original_mod, "embed_scale", None)
+
+    def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
+        attention_mask = attention_mask[:, None, None, :]
+        encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)
+
+        return attention_mask, encoder_attention_mask
+
+    def apply_position_embedding(self, inputs_embeds, cache_position):
+        hidden_all = []
+        for i in range(inputs_embeds.shape[0]):
+            positions_idx = cache_position[i]
+            position_weight = self.embed_positions.weight[2:]
+            position = position_weight[positions_idx]
+            batch_hidden = position + inputs_embeds[i]
+            hidden_all.append(batch_hidden)
+        hidden_states = torch.stack(hidden_all, dim=0)
+
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        return hidden_states
+
+    def get_embedding(self):
+        if self.embed_scale is not None:
+            return lambda x: self.embed_tokens(x) * self.embed_scale
         else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            if not is_cross_attention:
-                key_states = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                value_states = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-
-            # need 4d shape (input tensors) for scaled_dot_product_attention
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states,
-                key_states,
-                value_states,
-                attn_mask=attention_mask,
-            )
-            attn_output = attn_output.transpose(1, 2)
-            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        attn_output = self.out_proj(attn_output)
-
-        present_key_value = (key_states, value_states)
-
-        return attn_output, present_key_value
-
-
-ATTN_FORWARD_MAP = {"eager": _BartAttention.forward, "sdpa": _BartSdpaAttention.forward}
-
-
-class _BartDecoderLayer(BartDecoderLayer):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        past_key_value: Tuple[torch.Tensor],
-        cache_position: torch.Tensor,
-        batch_ids: torch.Tensor,
-        attn_impl: str = "eager",
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-        # Self Attention Block
-        residual = hidden_states
-        self_attn_past_key_value = past_key_value[:2]
-
-        hidden_states, present_key_value = ATTN_FORWARD_MAP[attn_impl](
-            self.self_attn,
-            hidden_states=hidden_states,
-            past_key_value=self_attn_past_key_value,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-            batch_index=batch_ids,
-        )
-        hidden_states = residual + hidden_states
-        hidden_states = self.self_attn_layer_norm(hidden_states)
+            return self.embed_tokens

-        # Cross-Attention Block
-        residual = hidden_states
-        cross_attn_past_key_value = past_key_value[-2:]
-
-        hidden_states, cross_attn_present_key_value = ATTN_FORWARD_MAP[attn_impl](
-            self.encoder_attn,
-            hidden_states=hidden_states,
-            key_value_states=encoder_hidden_states,
-            past_key_value=cross_attn_past_key_value,
-            attention_mask=encoder_attention_mask,
-            cache_position=cache_position,
-            batch_index=batch_ids,
-        )
-        hidden_states = residual + hidden_states
-        hidden_states = self.encoder_attn_layer_norm(hidden_states)
-        present_key_value = present_key_value + cross_attn_present_key_value

-        # Fully Connected Block
+class BartLayerFF(nn.Module):
+    def __init__(self, decoder_layer):
+        super().__init__()
+        self.fc1 = decoder_layer.fc1
+        self.fc2 = decoder_layer.fc2
+        self.activation_fn = decoder_layer.activation_fn
+        self.layer_norm = decoder_layer.final_layer_norm
+
+    def forward(self, hidden_states):
+        # Residual Connection
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
         hidden_states = self.fc2(hidden_states)
         hidden_states = residual + hidden_states
-        hidden_states = self.final_layer_norm(hidden_states)
-
-        return hidden_states, present_key_value
-
-
-class _BartDecoder(BartDecoder):
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        past_key_values: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_ids: torch.Tensor,
-        attn_impl: str = "eager",
-    ):
-        # embedding
-        if hasattr(self, "embed_scale"):
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-        else:
-            inputs_embeds = self.embed_tokens(input_ids)
+        hidden_states = self.layer_norm(hidden_states)
+        return hidden_states

-        if cache_position.dim() == 0:
-            positions_idx = cache_position + self.embed_positions.offset
-            positions = self.embed_positions.weight[positions_idx]
-            hidden_states = inputs_embeds + positions
-        else:
-            hidden_all = []
-            # compiler pattern base dependency -> take + add
-            for i in range(input_ids.shape[0]):
-                # cache position [N,1]
-                positions_idx = cache_position[i]
-                # offset is set 2 in bart embedding
-                position_weight = self.embed_positions.weight[2:]
-                position = position_weight[positions_idx]
-                batch_hidden = position + inputs_embeds[i]
-                hidden_all.append(batch_hidden)
-            hidden_states = torch.stack(hidden_all, dim=0)

-        hidden_states = self.layernorm_embedding(hidden_states)
+class BartDecoderLayer(Seq2SeqDecoderLayer):
+    def __post_init__(self):
+        self.self_attn_layer_norm = self._original_mod.self_attn_layer_norm
+        self.encoder_attn = self._original_mod.encoder_attn
+        self.encoder_attn_layer_norm = self._original_mod.encoder_attn_layer_norm
+        self.ff_layer = BartLayerFF(self._original_mod)

-        # prepare attn_mask
-        input_shape = input_ids.size()
-        if self._use_sdpa:
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask, input_shape, inputs_embeds, cache_position
-            )
-            encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                encoder_attention_mask, torch.float32, tgt_len=input_shape[-1]
-            )
-        else:
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, input_shape, inputs_embeds, cache_position
-            )
-            encoder_attention_mask = _prepare_4d_attention_mask(
-                encoder_attention_mask, torch.float32, tgt_len=input_shape[-1]
-            )
-
-        # iterate decoder_layer
-        next_decoder_cache = ()
-        for idx, decoder_layer in enumerate(self.layers):
-            past_key_value = past_key_values[idx]
-            layer_outputs = _BartDecoderLayer.forward(
-                decoder_layer,
-                hidden_states,
-                attention_mask=attention_mask,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                past_key_value=past_key_value,
-                cache_position=cache_position,
-                batch_ids=batch_ids,
-                attn_impl=attn_impl,
-            )
-            hidden_states = layer_outputs[0]
-            next_decoder_cache += (layer_outputs[1],)
-
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
-        )
-
-
-class BartDecoderWrapper(torch.nn.Module):
-    def __init__(self, model: "BartForConditionalGeneration"):
-        super().__init__()
-        self.config = model.config
-        self.decoder = model.get_decoder()
-        self.num_layers = self.config.decoder_layers
-        self.lm_head = model.lm_head
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_position: torch.Tensor,
-        self_kv_cache: torch.Tensor,
-        cross_kv_cache: torch.Tensor,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
-        if input_ids.shape[1] == 1:
-            rbln_batch_position = None
-        else:
-            rbln_batch_position = batch_position
-        # prepare past_key_values
-        kv_cache = ()
-        for i in range(0, self.num_layers * 2, 2):
-            kv_cache = kv_cache + (
-                (
-                    self_kv_cache[i],
-                    self_kv_cache[i + 1],
-                    cross_kv_cache[i],
-                    cross_kv_cache[i + 1],
-                ),
-            )
-        # decode
-        decoder_outputs = _BartDecoder.forward(
-            self.decoder,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_attention_mask=encoder_attention_mask,
-            cache_position=cache_position,
-            past_key_values=kv_cache,
-            encoder_hidden_states=torch.tensor([1]),
-            attn_impl=self.config._attn_implementation,
-            batch_ids=rbln_batch_position,
-        )
-        sequence_output = decoder_outputs[0]
-        lm_logits = self.lm_head(sequence_output)
-
-        # get self_kv_cache from ouputs
-        past_key_values = decoder_outputs[1]
-        self_kv_cache = []
-        for i in range(self.num_layers):
-            self_kv_cache.append(past_key_values[i][0])
-            self_kv_cache.append(past_key_values[i][1])
-        self_kv_cache = torch.stack(self_kv_cache, dim=0)
-
-        # return batch_position to keep it as a variable within the graph
-        return lm_logits, self_kv_cache, batch_position
-
-
-class BartEncoderWrapper(torch.nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.model = model
-        self.config = model.config
-        self.decoder = model.get_decoder()
-        self.encoder = model.get_encoder()
-        self.num_layers = self.config.encoder_layers
-        self.decoder_max_length = self.config.max_position_embeddings
-        self.encoder_max_length = self.config.max_position_embeddings
-        self.num_heads = self.config.decoder_attention_heads
-        self.d_kv = self.config.d_model // self.num_heads
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.LongTensor,
-        cross_key_value: torch.Tensor = None,
-        batch_idx: torch.Tensor = None,
-    ) -> Tuple[torch.Tensor]:
-        # 1. run encoder
-        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
-        last_hidden_states = encoder_outputs[0]
-
-        # 2. run dummy decoder to get pre-calculated cross-key_values for generation
-        dummy_past_key_value = []
-        for _ in range(self.num_layers):
-            pkv_self_attn_key = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
-            pkv_self_attn_value = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
-            pkv_cross_attn_key = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
-            pkv_cross_attn_value = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
-            layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
-            dummy_past_key_value.append(layer_pkv)
-
-        decoder_attention_mask = torch.zeros(1, self.decoder_max_length, dtype=torch.float32)
-        decoder_attention_mask[:, :1] = 1
-
-        decoder_outputs = _BartDecoder.forward(
-            self.decoder,
-            input_ids=torch.zeros((1, 1), dtype=torch.int64),
-            attention_mask=decoder_attention_mask,
-            encoder_attention_mask=attention_mask,
-            cache_position=torch.tensor(0, dtype=torch.int32),
-            encoder_hidden_states=last_hidden_states,
-            past_key_values=dummy_past_key_value,
-            batch_ids=torch.tensor(0, dtype=torch.int32),
-            attn_impl=self.config._attn_implementation,
-        )
-        first_past_kv = decoder_outputs[1]
-
-        encoder_kv = []
-        for i in range(self.model.config.decoder_layers):
-            encoder_kv.append(first_past_kv[i][2].unsqueeze(0))
-            encoder_kv.append(first_past_kv[i][3].unsqueeze(0))
-        encoder_kv = torch.cat(encoder_kv, dim=0)
-
-        cross_key_value = cross_key_value.slice_scatter(encoder_kv, dim=1, start=batch_idx, end=batch_idx + 1)
-
-        return cross_key_value
+    def pre_self_attn_layer_norm(self, hidden_states):
+        return hidden_states
+
+    def post_self_attn_layer_norm(self, hidden_states):
+        return self.self_attn_layer_norm(hidden_states)
+
+    def pre_cross_attn_layer_norm(self, hidden_states):
+        return hidden_states
+
+    def post_cross_attn_layer_norm(self, hidden_states):
+        return self.encoder_attn_layer_norm(hidden_states)
+
+
+class BartSelfAttention(Seq2SeqSelfAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.out_proj = self._original_mod.out_proj
+        self.num_heads = self._original_mod.num_heads
+        self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+        self.scaling = self.head_dim**-0.5
+        self.attn_decode = torch.ops.rbln_custom_ops.attn_decode
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states = self.q_proj(hidden_states) * self.scaling
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        return query_states, key_states, value_states
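
The hunk above removes the hand-written BART attention, decoder-layer, and wrapper overrides and replaces them with thin subclasses of the shared seq2seq base classes, with BartWrapper now requiring the encoder max sequence length up front. A minimal sketch of how the refactored wrapper is assembled from a Hugging Face checkpoint; the checkpoint name and the 1024 value are illustrative assumptions, not taken from the diff:

# Illustrative sketch only; mirrors what wrap_model_if_needed does in 0.2.0.
from transformers import BartForConditionalGeneration as HFBartForConditionalGeneration

from optimum.rbln.transformers.models.bart.bart_architecture import BartWrapper

# Assumed checkpoint; any BART conditional-generation model would do.
hf_model = HFBartForConditionalGeneration.from_pretrained("facebook/bart-base")

# enc_max_seq_len bounds the encoder-side sequence length; 1024 matches the
# fallback used in modeling_bart.py (see the next file's hunks).
wrapper = BartWrapper(hf_model, enc_max_seq_len=1024)
encoder_module = wrapper.encoder  # Seq2SeqEncoderWrapper(model, enc_max_seq_len)
decoder_module = wrapper.decoder  # BartDecoderWrapper(model)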
optimum/rbln/transformers/models/bart/modeling_bart.py
@@ -24,9 +24,9 @@
 import inspect
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

-from transformers import BartConfig, BartForConditionalGeneration, BartModel, PretrainedConfig
+from transformers import BartForConditionalGeneration, PretrainedConfig, PreTrainedModel

-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ...models.seq2seq import RBLNModelForSeq2SeqLM
@@ -41,9 +41,6 @@ if TYPE_CHECKING:


 class RBLNBartModel(RBLNModel):
-    original_model_class = BartModel
-    original_config_class = BartConfig
-
     @classmethod
     def _get_rbln_config(
         cls,
@@ -82,7 +79,7 @@ class RBLNBartModel(RBLNModel):
         if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
             rbln_model_input_names = cls.rbln_model_input_names
         elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-            input_names_order = inspect.signature(cls.original_model_class.forward).parameters.keys()
+            input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
             raise ValueError(
                 "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                 f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(input_names_order)})"
@@ -96,11 +93,12 @@
             for model_input_name in rbln_model_input_names
         ]

-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        enc_compile_config = RBLNCompileConfig(input_info=input_info, compiled_model_name="encoder")
+        dec_compile_config = RBLNCompileConfig(input_info=input_info, compiled_model_name="decoder")

         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
+            compile_cfgs=[enc_compile_config, dec_compile_config],
             rbln_kwargs=rbln_kwargs,
         )

@@ -111,7 +109,10 @@
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        return BartWrapper(model)
+        enc_max_seq_len = (
+            rbln_config.model_cfg["enc_max_seq_len"] if "enc_max_seq_len" in rbln_config.model_cfg else 1024
+        )
+        return BartWrapper(model, enc_max_seq_len=enc_max_seq_len)

     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
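
The wrap_model_if_needed hunk above reads enc_max_seq_len from rbln_config.model_cfg and falls back to 1024 when it is absent. A hedged sketch of how that value might be supplied at export time, assuming the usual optimum-rbln from_pretrained(export=True, rbln_*) convention; the rbln_enc_max_seq_len keyword and the checkpoint name are assumptions, not confirmed by this diff:

# Hypothetical export call; rbln_enc_max_seq_len is assumed to land in
# rbln_config.model_cfg["enc_max_seq_len"], which wrap_model_if_needed reads.
from optimum.rbln import RBLNBartForConditionalGeneration

model = RBLNBartForConditionalGeneration.from_pretrained(
    "facebook/bart-base",        # assumed checkpoint
    export=True,                 # compile the PyTorch weights for RBLN NPUs
    rbln_enc_max_seq_len=1024,   # if omitted, the code above falls back to 1024
)
model.save_pretrained("bart-base-rbln")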