optimum-rbln 0.1.15__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +26 -33
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/__init__.py +4 -0
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
- optimum/rbln/diffusers/models/__init__.py +2 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
- optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
- optimum/rbln/diffusers/models/controlnet.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
- optimum/rbln/diffusers/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
- optimum/rbln/modeling.py +13 -347
- optimum/rbln/modeling_base.py +24 -4
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -0
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
- optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/utils/rbln_quantization.py +0 -1
- optimum/rbln/utils/decorator_utils.py +51 -15
- optimum/rbln/utils/import_utils.py +7 -0
- optimum/rbln/utils/logging.py +37 -0
- optimum/rbln/utils/model_utils.py +0 -1
- optimum/rbln/utils/runtime_utils.py +9 -3
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +23 -0
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/METADATA +37 -26
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/RECORD +76 -72
- optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +0 -0
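Two of the renames above move modules that downstream code may import directly: `modeling_diffusers.py` now lives under `diffusers/`, and `modeling_alias.py` under `transformers/`. A minimal sketch of the corresponding import update, assuming no compatibility alias re-exports the old location:

```python
# 0.1.15 module path (the file has been moved in 0.2.0):
# from optimum.rbln.modeling_diffusers import RBLNDiffusionMixin

# 0.2.0 module path, matching the relative-import change shown for modeling_clip.py below:
from optimum.rbln.diffusers.modeling_diffusers import RBLNDiffusionMixin
```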
optimum/rbln/transformers/models/bart/bart_architecture.py (some removed lines are truncated or blank in the source rendering; they are kept as-is below, with larger gaps marked in brackets):

```diff
@@ -21,497 +21,140 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from typing import
+from typing import Tuple
 
 import torch
 from torch import nn
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_attention_mask,
-    _prepare_4d_attention_mask_for_sdpa,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPastAndCrossAttentions,
-)
-from transformers.models.bart.modeling_bart import (
-    BartAttention,
-    BartDecoder,
-    BartDecoderLayer,
-    BartForConditionalGeneration,
-    BartSdpaAttention,
 )
 from transformers.utils import logging
 
+from ..seq2seq.seq2seq_architecture import (
+    Seq2SeqDecoder,
+    Seq2SeqDecoderLayer,
+    Seq2SeqDecoderWrapper,
+    Seq2SeqEncoderWrapper,
+    Seq2SeqForConditionalGeneration,
+    Seq2SeqSelfAttention,
+)
+
 
 logger = logging.get_logger(__name__)
 
 
 class BartWrapper:
-    def __init__(self, model):
-        self.encoder =
+    def __init__(self, model: nn.Module, enc_max_seq_len: int):
+        self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
         self.decoder = BartDecoderWrapper(model)
 
 
-class
-    def
[old lines 58-61 were removed but are not rendered in the source view]
-        cache_position: torch.Tensor,
-        batch_index: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-        bsz, tgt_len, _ = hidden_states.size()
-        is_cross_attention = key_value_states is not None
+class BartDecoderWrapper(Seq2SeqDecoderWrapper):
+    def convert_to_rbln_conditional_generation(self, model: nn.Module):
+        new_layers = []
+        for layer in model.get_decoder().layers:
+            self_attn = BartSelfAttention(layer.self_attn)
+            new_layers.append(BartDecoderLayer(layer, self_attn))
 
-
+        decoder_model = BartDecoder(model.get_decoder(), new_layers)
+        new_model = BartForConditionalGeneration(model, decoder_model)
 
-
-            is_dummy_decoder = len(key_value_states.shape) > 1
-            if is_dummy_decoder:
-                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-            else:
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-        else:
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        if cache_position.dim() > 0:
-            proj_shape = (bsz, self.num_heads, -1, self.head_dim)
-            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-            key_states = key_states.reshape(*proj_shape)
-            value_states = value_states.reshape(*proj_shape)
-
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-            for b in range(bsz):
-                batch_query_states = query_states[b].unsqueeze(0).unsqueeze(2)
-                batch_attention_mask = attention_mask[b].unsqueeze(0).unsqueeze(2)
-                batch_key_states = key_states[b].unsqueeze(0).unsqueeze(2)
-                batch_value_states = value_states[b].unsqueeze(0).unsqueeze(2)
-                if not is_cross_attention:
-                    batch_key_states = (
-                        past_key_value[0][b]
-                        .unsqueeze(0)
-                        .unsqueeze(2)
-                        .slice_scatter(
-                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-                    batch_value_states = (
-                        past_key_value[1][b]
-                        .unsqueeze(0)
-                        .unsqueeze(2)
-                        .slice_scatter(
-                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-                attn_weights = torch.matmul(batch_query_states, batch_key_states.transpose(3, 4))
-                attn_weights = attn_weights + batch_attention_mask
-                attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-                attn_output = torch.matmul(attn_weights, batch_value_states)
-                attn_output = attn_output.view(1, self.num_heads, tgt_len, self.head_dim)
-                attn_output = attn_output.transpose(1, 2)
-                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
-                all_key_states.append(batch_key_states)
-                all_value_states.append(batch_value_states)
-                all_attn_output.append(attn_output)
-            key_states = torch.cat(all_key_states, dim=0).squeeze(2)
-            value_states = torch.cat(all_value_states, dim=0).squeeze(2)
-            attn_output = torch.cat(all_attn_output, dim=0)
+        return new_model
 
-        else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            if not is_cross_attention:
-                key_states = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                value_states = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-
-            proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-            key_states = key_states.reshape(*proj_shape)
-            value_states = value_states.reshape(*proj_shape)
-
-            src_len = key_states.size(1)
-            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-            attn_output = torch.bmm(attn_weights, value_states)
-            attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            attn_output = attn_output.transpose(1, 2)
-            key_states = key_states.unsqueeze(0)
-            value_states = value_states.unsqueeze(0)
-            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        attn_output = self.out_proj(attn_output)
-
-        present_key_value = (key_states, value_states)
-
-        return attn_output, present_key_value
-
-
-class _BartSdpaAttention(BartSdpaAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        past_key_value: Tuple[torch.Tensor],
-        attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_index: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-        bsz, tgt_len, _ = hidden_states.size()
-        is_cross_attention = key_value_states is not None
-
-        query_states = self.q_proj(hidden_states)
-
-        if is_cross_attention:
-            is_dummy_decoder = len(key_value_states.shape) > 1
-            if is_dummy_decoder:
-                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-            else:
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-        else:
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        query_states = self._shape(query_states, tgt_len, bsz)
-
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-
-            for b in range(bsz):
-                batch_query_states = query_states[b].unsqueeze(0)
-                batch_attention_mask = attention_mask[b].unsqueeze(0)
-                batch_key_states = key_states[b].unsqueeze(0)
-                batch_value_states = value_states[b].unsqueeze(0)
-
-                if not is_cross_attention:
-                    batch_key_states = (
-                        past_key_value[0][b]
-                        .unsqueeze(0)
-                        .slice_scatter(
-                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-                    batch_value_states = (
-                        past_key_value[1][b]
-                        .unsqueeze(0)
-                        .slice_scatter(
-                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-
-                attn_output = torch.nn.functional.scaled_dot_product_attention(
-                    batch_query_states, batch_key_states, batch_value_states, attn_mask=batch_attention_mask
-                )
-                attn_output = attn_output.transpose(1, 2)
-                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
-                all_key_states.append(batch_key_states)
-                all_value_states.append(batch_value_states)
-                all_attn_output.append(attn_output)
-
-            key_states = torch.cat(all_key_states, dim=0)
-            value_states = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
 
+class BartForConditionalGeneration(Seq2SeqForConditionalGeneration):
+    has_rescaling = False
+
+    def __post_init__(self):
+        self.scaling = self.config.d_model**-0.5
+
+
+class BartDecoder(Seq2SeqDecoder):
+    has_pos_emb = True
+
+    def __post_init__(self):
+        self.embed_positions = self._original_mod.embed_positions
+        self.layernorm_embedding = self._original_mod.layernorm_embedding
+        self.embed_scale = getattr(self._original_mod, "embed_scale", None)
+
+    def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
+        attention_mask = attention_mask[:, None, None, :]
+        encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)
+
+        return attention_mask, encoder_attention_mask
+
+    def apply_position_embedding(self, inputs_embeds, cache_position):
+        hidden_all = []
+        for i in range(inputs_embeds.shape[0]):
+            positions_idx = cache_position[i]
+            position_weight = self.embed_positions.weight[2:]
+            position = position_weight[positions_idx]
+            batch_hidden = position + inputs_embeds[i]
+            hidden_all.append(batch_hidden)
+        hidden_states = torch.stack(hidden_all, dim=0)
+
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        return hidden_states
+
+    def get_embedding(self):
+        if self.embed_scale is not None:
+            return lambda x: self.embed_tokens(x) * self.embed_scale
         else:
-
-                batch_index = 0
-
-            if not is_cross_attention:
-                key_states = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                value_states = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-
-            # need 4d shape (input tensors) for scaled_dot_product_attention
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states,
-                key_states,
-                value_states,
-                attn_mask=attention_mask,
-            )
-            attn_output = attn_output.transpose(1, 2)
-            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        attn_output = self.out_proj(attn_output)
-
-        present_key_value = (key_states, value_states)
-
-        return attn_output, present_key_value
-
-
-ATTN_FORWARD_MAP = {"eager": _BartAttention.forward, "sdpa": _BartSdpaAttention.forward}
-
-
-class _BartDecoderLayer(BartDecoderLayer):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        past_key_value: Tuple[torch.Tensor],
-        cache_position: torch.Tensor,
-        batch_ids: torch.Tensor,
-        attn_impl: str = "eager",
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-        # Self Attention Block
-        residual = hidden_states
-        self_attn_past_key_value = past_key_value[:2]
-
-        hidden_states, present_key_value = ATTN_FORWARD_MAP[attn_impl](
-            self.self_attn,
-            hidden_states=hidden_states,
-            past_key_value=self_attn_past_key_value,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-            batch_index=batch_ids,
-        )
-        hidden_states = residual + hidden_states
-        hidden_states = self.self_attn_layer_norm(hidden_states)
+            return self.embed_tokens
 
-        # Cross-Attention Block
-        residual = hidden_states
-        cross_attn_past_key_value = past_key_value[-2:]
-
-        hidden_states, cross_attn_present_key_value = ATTN_FORWARD_MAP[attn_impl](
-            self.encoder_attn,
-            hidden_states=hidden_states,
-            key_value_states=encoder_hidden_states,
-            past_key_value=cross_attn_past_key_value,
-            attention_mask=encoder_attention_mask,
-            cache_position=cache_position,
-            batch_index=batch_ids,
-        )
-        hidden_states = residual + hidden_states
-        hidden_states = self.encoder_attn_layer_norm(hidden_states)
-        present_key_value = present_key_value + cross_attn_present_key_value
 
-
+class BartLayerFF(nn.Module):
+    def __init__(self, decoder_layer):
+        super().__init__()
+        self.fc1 = decoder_layer.fc1
+        self.fc2 = decoder_layer.fc2
+        self.activation_fn = decoder_layer.activation_fn
+        self.layer_norm = decoder_layer.final_layer_norm
+
+    def forward(self, hidden_states):
+        # Residual Connection
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
         hidden_states = self.fc2(hidden_states)
         hidden_states = residual + hidden_states
-        hidden_states = self.
-
-        return hidden_states, present_key_value
-
-
-class _BartDecoder(BartDecoder):
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        past_key_values: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_ids: torch.Tensor,
-        attn_impl: str = "eager",
-    ):
-        # embedding
-        if hasattr(self, "embed_scale"):
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-        else:
-            inputs_embeds = self.embed_tokens(input_ids)
+        hidden_states = self.layer_norm(hidden_states)
+        return hidden_states
 
-        if cache_position.dim() == 0:
-            positions_idx = cache_position + self.embed_positions.offset
-            positions = self.embed_positions.weight[positions_idx]
-            hidden_states = inputs_embeds + positions
-        else:
-            hidden_all = []
-            # compiler pattern base dependency -> take + add
-            for i in range(input_ids.shape[0]):
-                # cache position [N,1]
-                positions_idx = cache_position[i]
-                # offset is set 2 in bart embedding
-                position_weight = self.embed_positions.weight[2:]
-                position = position_weight[positions_idx]
-                batch_hidden = position + inputs_embeds[i]
-                hidden_all.append(batch_hidden)
-            hidden_states = torch.stack(hidden_all, dim=0)
 
-
+class BartDecoderLayer(Seq2SeqDecoderLayer):
+    def __post_init__(self):
+        self.self_attn_layer_norm = self._original_mod.self_attn_layer_norm
+        self.encoder_attn = self._original_mod.encoder_attn
+        self.encoder_attn_layer_norm = self._original_mod.encoder_attn_layer_norm
+        self.ff_layer = BartLayerFF(self._original_mod)
 
[old lines 358-386 were removed but are not rendered in the source view]
-                batch_ids=batch_ids,
-                attn_impl=attn_impl,
-            )
-            hidden_states = layer_outputs[0]
-            next_decoder_cache += (layer_outputs[1],)
-
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
-        )
-
-
-class BartDecoderWrapper(torch.nn.Module):
-    def __init__(self, model: "BartForConditionalGeneration"):
-        super().__init__()
-        self.config = model.config
-        self.decoder = model.get_decoder()
-        self.num_layers = self.config.decoder_layers
-        self.lm_head = model.lm_head
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_position: torch.Tensor,
-        self_kv_cache: torch.Tensor,
-        cross_kv_cache: torch.Tensor,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
-        if input_ids.shape[1] == 1:
-            rbln_batch_position = None
-        else:
-            rbln_batch_position = batch_position
-        # prepare past_key_values
-        kv_cache = ()
-        for i in range(0, self.num_layers * 2, 2):
-            kv_cache = kv_cache + (
-                (
-                    self_kv_cache[i],
-                    self_kv_cache[i + 1],
-                    cross_kv_cache[i],
-                    cross_kv_cache[i + 1],
-                ),
-            )
-        # decode
-        decoder_outputs = _BartDecoder.forward(
-            self.decoder,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_attention_mask=encoder_attention_mask,
-            cache_position=cache_position,
-            past_key_values=kv_cache,
-            encoder_hidden_states=torch.tensor([1]),
-            attn_impl=self.config._attn_implementation,
-            batch_ids=rbln_batch_position,
-        )
-        sequence_output = decoder_outputs[0]
-        lm_logits = self.lm_head(sequence_output)
-
-        # get self_kv_cache from ouputs
-        past_key_values = decoder_outputs[1]
-        self_kv_cache = []
-        for i in range(self.num_layers):
-            self_kv_cache.append(past_key_values[i][0])
-            self_kv_cache.append(past_key_values[i][1])
-        self_kv_cache = torch.stack(self_kv_cache, dim=0)
-
-        # return batch_position to keep it as a variable within the graph
-        return lm_logits, self_kv_cache, batch_position
-
-
-class BartEncoderWrapper(torch.nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.model = model
-        self.config = model.config
-        self.decoder = model.get_decoder()
-        self.encoder = model.get_encoder()
-        self.num_layers = self.config.encoder_layers
-        self.decoder_max_length = self.config.max_position_embeddings
-        self.encoder_max_length = self.config.max_position_embeddings
-        self.num_heads = self.config.decoder_attention_heads
-        self.d_kv = self.config.d_model // self.num_heads
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.LongTensor,
-        cross_key_value: torch.Tensor = None,
-        batch_idx: torch.Tensor = None,
-    ) -> Tuple[torch.Tensor]:
-        # 1. run encoder
-        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
-        last_hidden_states = encoder_outputs[0]
-
-        # 2. run dummy decoder to get pre-calculated cross-key_values for generation
-        dummy_past_key_value = []
-        for _ in range(self.num_layers):
-            pkv_self_attn_key = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
-            pkv_self_attn_value = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
-            pkv_cross_attn_key = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
-            pkv_cross_attn_value = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
-            layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
-            dummy_past_key_value.append(layer_pkv)
-
-        decoder_attention_mask = torch.zeros(1, self.decoder_max_length, dtype=torch.float32)
-        decoder_attention_mask[:, :1] = 1
-
-        decoder_outputs = _BartDecoder.forward(
-            self.decoder,
-            input_ids=torch.zeros((1, 1), dtype=torch.int64),
-            attention_mask=decoder_attention_mask,
-            encoder_attention_mask=attention_mask,
-            cache_position=torch.tensor(0, dtype=torch.int32),
-            encoder_hidden_states=last_hidden_states,
-            past_key_values=dummy_past_key_value,
-            batch_ids=torch.tensor(0, dtype=torch.int32),
-            attn_impl=self.config._attn_implementation,
-        )
-        first_past_kv = decoder_outputs[1]
-
-        encoder_kv = []
-        for i in range(self.model.config.decoder_layers):
-            encoder_kv.append(first_past_kv[i][2].unsqueeze(0))
-            encoder_kv.append(first_past_kv[i][3].unsqueeze(0))
-        encoder_kv = torch.cat(encoder_kv, dim=0)
-
-        cross_key_value = cross_key_value.slice_scatter(encoder_kv, dim=1, start=batch_idx, end=batch_idx + 1)
-
-        return cross_key_value
+    def pre_self_attn_layer_norm(self, hidden_states):
+        return hidden_states
+
+    def post_self_attn_layer_norm(self, hidden_states):
+        return self.self_attn_layer_norm(hidden_states)
+
+    def pre_cross_attn_layer_norm(self, hidden_states):
+        return hidden_states
+
+    def post_cross_attn_layer_norm(self, hidden_states):
+        return self.encoder_attn_layer_norm(hidden_states)
+
+
+class BartSelfAttention(Seq2SeqSelfAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.out_proj = self._original_mod.out_proj
+        self.num_heads = self._original_mod.num_heads
+        self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+        self.scaling = self.head_dim**-0.5
+        self.attn_decode = torch.ops.rbln_custom_ops.attn_decode
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states = self.q_proj(hidden_states) * self.scaling
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        return query_states, key_states, value_states
```
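The hunk above drops the hand-copied `_BartAttention`, `_BartSdpaAttention`, `_BartDecoder`, and encoder/decoder wrapper code in favor of thin adapters over the shared `seq2seq` base classes, and `BartWrapper` now takes the encoder's maximum sequence length explicitly. A minimal sketch of how the 0.2.0 wrapper is assembled from a stock Hugging Face checkpoint; the checkpoint name is only an example, and in normal use `wrap_model_if_needed` (shown in the next file) builds the wrapper for you:

```python
from transformers import BartForConditionalGeneration as HFBartForConditionalGeneration

from optimum.rbln.transformers.models.bart.bart_architecture import BartWrapper

# Example checkpoint; any BART conditional-generation checkpoint would do.
hf_model = HFBartForConditionalGeneration.from_pretrained("facebook/bart-base")

# New 0.2.0 signature: the encoder max sequence length is passed explicitly and
# forwarded to Seq2SeqEncoderWrapper; the decoder side is BartDecoderWrapper,
# which rebuilds each decoder layer around the shared Seq2Seq* adapters.
wrapper = BartWrapper(hf_model, enc_max_seq_len=1024)
encoder_module = wrapper.encoder  # Seq2SeqEncoderWrapper(model, enc_max_seq_len)
decoder_module = wrapper.decoder  # BartDecoderWrapper(model)
```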
optimum/rbln/transformers/models/bart/modeling_bart.py:

```diff
@@ -24,7 +24,7 @@
 import inspect
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
 
-from transformers import BartForConditionalGeneration, PretrainedConfig
+from transformers import BartForConditionalGeneration, PretrainedConfig, PreTrainedModel
 
 from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
@@ -93,11 +93,12 @@ class RBLNBartModel(RBLNModel):
             for model_input_name in rbln_model_input_names
         ]
 
-
+        enc_compile_config = RBLNCompileConfig(input_info=input_info, compiled_model_name="encoder")
+        dec_compile_config = RBLNCompileConfig(input_info=input_info, compiled_model_name="decoder")
 
         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
-            compile_cfgs=[
+            compile_cfgs=[enc_compile_config, dec_compile_config],
             rbln_kwargs=rbln_kwargs,
         )
 
@@ -108,7 +109,10 @@ class RBLNBartModel(RBLNModel):
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-
+        enc_max_seq_len = (
+            rbln_config.model_cfg["enc_max_seq_len"] if "enc_max_seq_len" in rbln_config.model_cfg else 1024
+        )
+        return BartWrapper(model, enc_max_seq_len=enc_max_seq_len)
 
     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
```
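`wrap_model_if_needed` now reads `enc_max_seq_len` from `rbln_config.model_cfg`, falling back to 1024 when the key is absent. A hedged sketch of what this looks like from the export API; `from_pretrained(..., export=True)` is the usual optimum-rbln entry point, but the exact `rbln_enc_max_seq_len` keyword is an assumption inferred from the `"enc_max_seq_len"` entry above:

```python
from optimum.rbln import RBLNBartForConditionalGeneration

# Sketch only. If the keyword is omitted, the code above falls back to an
# encoder max sequence length of 1024 when building BartWrapper.
model = RBLNBartForConditionalGeneration.from_pretrained(
    "facebook/bart-base",      # example checkpoint
    export=True,               # compile for RBLN NPUs
    rbln_enc_max_seq_len=512,  # assumed rbln_* keyword for "enc_max_seq_len"
)
```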
optimum/rbln/transformers/models/clip/modeling_clip.py:

```diff
@@ -34,9 +34,9 @@ from transformers import (
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.models.clip.modeling_clip import CLIPTextModelOutput
 
+from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
 from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
-from ....modeling_diffusers import RBLNDiffusionMixin
 
 
 logger = logging.getLogger(__name__)
```
optimum/rbln/transformers/models/decoderonly/__init__.py:

```diff
@@ -21,11 +21,4 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from .decoderonly_architecture import (
-    DecoderOnlyWrapper,
-    RotaryEmbedding,
-    apply_rotary_pos_emb,
-    rotate_half,
-    slice_and_unsqueeze_cos_sin,
-)
 from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
```
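The `decoderonly` package no longer re-exports its architecture helpers; only `RBLNDecoderOnlyModelForCausalLM` remains at package level. A sketch of the import adjustment this implies, assuming the helper classes are still defined in `decoderonly_architecture.py` after its own 0.2.0 rewrite:

```python
# Still re-exported by the subpackage in 0.2.0:
from optimum.rbln.transformers.models.decoderonly import RBLNDecoderOnlyModelForCausalLM

# No longer re-exported at package level; import from the architecture module
# directly (assuming DecoderOnlyWrapper survives the 0.2.0 rewrite of that file):
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import DecoderOnlyWrapper
```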