optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- optimum/rbln/__init__.py +41 -38
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +26 -2
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
- optimum/rbln/diffusers/models/__init__.py +36 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
- optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
- optimum/rbln/diffusers/models/controlnet.py +54 -14
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
- optimum/rbln/diffusers/pipelines/__init__.py +23 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/modeling.py +238 -0
- optimum/rbln/modeling_base.py +186 -760
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -2
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
- optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
- optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
- optimum/rbln/utils/decorator_utils.py +51 -11
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +22 -1
- optimum/rbln/utils/logging.py +37 -0
- optimum/rbln/utils/model_utils.py +52 -0
- optimum/rbln/utils/runtime_utils.py +10 -4
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +137 -0
- optimum_rbln-0.2.0.dist-info/METADATA +117 -0
- optimum_rbln-0.2.0.dist-info/RECORD +114 -0
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
- optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum/rbln/utils/context.py +0 -58
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.13.dist-info/METADATA +0 -120
- optimum_rbln-0.1.13.dist-info/RECORD +0 -107
- optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
- optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
optimum/rbln/transformers/models/bart/bart_architecture.py

@@ -21,497 +21,140 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from typing import
+from typing import Tuple

 import torch
 from torch import nn
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_attention_mask,
-    _prepare_4d_attention_mask_for_sdpa,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPastAndCrossAttentions,
-)
-from transformers.models.bart.modeling_bart import (
-    BartAttention,
-    BartDecoder,
-    BartDecoderLayer,
-    BartForConditionalGeneration,
-    BartSdpaAttention,
 )
 from transformers.utils import logging

+from ..seq2seq.seq2seq_architecture import (
+    Seq2SeqDecoder,
+    Seq2SeqDecoderLayer,
+    Seq2SeqDecoderWrapper,
+    Seq2SeqEncoderWrapper,
+    Seq2SeqForConditionalGeneration,
+    Seq2SeqSelfAttention,
+)
+

 logger = logging.get_logger(__name__)


 class BartWrapper:
-    def __init__(self, model):
-        self.encoder =
+    def __init__(self, model: nn.Module, enc_max_seq_len: int):
+        self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
         self.decoder = BartDecoderWrapper(model)


-class
-    def
-
-
-
-
-        cache_position: torch.Tensor,
-        batch_index: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-        bsz, tgt_len, _ = hidden_states.size()
-        is_cross_attention = key_value_states is not None
+class BartDecoderWrapper(Seq2SeqDecoderWrapper):
+    def convert_to_rbln_conditional_generation(self, model: nn.Module):
+        new_layers = []
+        for layer in model.get_decoder().layers:
+            self_attn = BartSelfAttention(layer.self_attn)
+            new_layers.append(BartDecoderLayer(layer, self_attn))

-
+        decoder_model = BartDecoder(model.get_decoder(), new_layers)
+        new_model = BartForConditionalGeneration(model, decoder_model)

-
-            is_dummy_decoder = len(key_value_states.shape) > 1
-            if is_dummy_decoder:
-                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-            else:
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-        else:
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        if cache_position.dim() > 0:
-            proj_shape = (bsz, self.num_heads, -1, self.head_dim)
-            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-            key_states = key_states.reshape(*proj_shape)
-            value_states = value_states.reshape(*proj_shape)
-
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-            for b in range(bsz):
-                batch_query_states = query_states[b].unsqueeze(0).unsqueeze(2)
-                batch_attention_mask = attention_mask[b].unsqueeze(0).unsqueeze(2)
-                batch_key_states = key_states[b].unsqueeze(0).unsqueeze(2)
-                batch_value_states = value_states[b].unsqueeze(0).unsqueeze(2)
-                if not is_cross_attention:
-                    batch_key_states = (
-                        past_key_value[0][b]
-                        .unsqueeze(0)
-                        .unsqueeze(2)
-                        .slice_scatter(
-                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-                    batch_value_states = (
-                        past_key_value[1][b]
-                        .unsqueeze(0)
-                        .unsqueeze(2)
-                        .slice_scatter(
-                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-                attn_weights = torch.matmul(batch_query_states, batch_key_states.transpose(3, 4))
-                attn_weights = attn_weights + batch_attention_mask
-                attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-                attn_output = torch.matmul(attn_weights, batch_value_states)
-                attn_output = attn_output.view(1, self.num_heads, tgt_len, self.head_dim)
-                attn_output = attn_output.transpose(1, 2)
-                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
-                all_key_states.append(batch_key_states)
-                all_value_states.append(batch_value_states)
-                all_attn_output.append(attn_output)
-            key_states = torch.cat(all_key_states, dim=0).squeeze(2)
-            value_states = torch.cat(all_value_states, dim=0).squeeze(2)
-            attn_output = torch.cat(all_attn_output, dim=0)
+        return new_model

-        else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            if not is_cross_attention:
-                key_states = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                value_states = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-
-            proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-            key_states = key_states.reshape(*proj_shape)
-            value_states = value_states.reshape(*proj_shape)
-
-            src_len = key_states.size(1)
-            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-            attn_output = torch.bmm(attn_weights, value_states)
-            attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            attn_output = attn_output.transpose(1, 2)
-            key_states = key_states.unsqueeze(0)
-            value_states = value_states.unsqueeze(0)
-            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        attn_output = self.out_proj(attn_output)
-
-        present_key_value = (key_states, value_states)
-
-        return attn_output, present_key_value
-
-
-class _BartSdpaAttention(BartSdpaAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        past_key_value: Tuple[torch.Tensor],
-        attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_index: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-        bsz, tgt_len, _ = hidden_states.size()
-        is_cross_attention = key_value_states is not None
-
-        query_states = self.q_proj(hidden_states)
-
-        if is_cross_attention:
-            is_dummy_decoder = len(key_value_states.shape) > 1
-            if is_dummy_decoder:
-                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-            else:
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-        else:
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        query_states = self._shape(query_states, tgt_len, bsz)
-
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-
-            for b in range(bsz):
-                batch_query_states = query_states[b].unsqueeze(0)
-                batch_attention_mask = attention_mask[b].unsqueeze(0)
-                batch_key_states = key_states[b].unsqueeze(0)
-                batch_value_states = value_states[b].unsqueeze(0)
-
-                if not is_cross_attention:
-                    batch_key_states = (
-                        past_key_value[0][b]
-                        .unsqueeze(0)
-                        .slice_scatter(
-                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-                    batch_value_states = (
-                        past_key_value[1][b]
-                        .unsqueeze(0)
-                        .slice_scatter(
-                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-
-                attn_output = torch.nn.functional.scaled_dot_product_attention(
-                    batch_query_states, batch_key_states, batch_value_states, attn_mask=batch_attention_mask
-                )
-                attn_output = attn_output.transpose(1, 2)
-                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
-                all_key_states.append(batch_key_states)
-                all_value_states.append(batch_value_states)
-                all_attn_output.append(attn_output)
-
-            key_states = torch.cat(all_key_states, dim=0)
-            value_states = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)

+class BartForConditionalGeneration(Seq2SeqForConditionalGeneration):
+    has_rescaling = False
+
+    def __post_init__(self):
+        self.scaling = self.config.d_model**-0.5
+
+
+class BartDecoder(Seq2SeqDecoder):
+    has_pos_emb = True
+
+    def __post_init__(self):
+        self.embed_positions = self._original_mod.embed_positions
+        self.layernorm_embedding = self._original_mod.layernorm_embedding
+        self.embed_scale = getattr(self._original_mod, "embed_scale", None)
+
+    def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
+        attention_mask = attention_mask[:, None, None, :]
+        encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)
+
+        return attention_mask, encoder_attention_mask
+
+    def apply_position_embedding(self, inputs_embeds, cache_position):
+        hidden_all = []
+        for i in range(inputs_embeds.shape[0]):
+            positions_idx = cache_position[i]
+            position_weight = self.embed_positions.weight[2:]
+            position = position_weight[positions_idx]
+            batch_hidden = position + inputs_embeds[i]
+            hidden_all.append(batch_hidden)
+        hidden_states = torch.stack(hidden_all, dim=0)
+
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        return hidden_states
+
+    def get_embedding(self):
+        if self.embed_scale is not None:
+            return lambda x: self.embed_tokens(x) * self.embed_scale
         else:
-
-            batch_index = 0
-
-            if not is_cross_attention:
-                key_states = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                value_states = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-
-            # need 4d shape (input tensors) for scaled_dot_product_attention
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states,
-                key_states,
-                value_states,
-                attn_mask=attention_mask,
-            )
-            attn_output = attn_output.transpose(1, 2)
-            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        attn_output = self.out_proj(attn_output)
-
-        present_key_value = (key_states, value_states)
-
-        return attn_output, present_key_value
-
-
-ATTN_FORWARD_MAP = {"eager": _BartAttention.forward, "sdpa": _BartSdpaAttention.forward}
-
-
-class _BartDecoderLayer(BartDecoderLayer):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        past_key_value: Tuple[torch.Tensor],
-        cache_position: torch.Tensor,
-        batch_ids: torch.Tensor,
-        attn_impl: str = "eager",
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-        # Self Attention Block
-        residual = hidden_states
-        self_attn_past_key_value = past_key_value[:2]
-
-        hidden_states, present_key_value = ATTN_FORWARD_MAP[attn_impl](
-            self.self_attn,
-            hidden_states=hidden_states,
-            past_key_value=self_attn_past_key_value,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-            batch_index=batch_ids,
-        )
-        hidden_states = residual + hidden_states
-        hidden_states = self.self_attn_layer_norm(hidden_states)
+            return self.embed_tokens

-        # Cross-Attention Block
-        residual = hidden_states
-        cross_attn_past_key_value = past_key_value[-2:]
-
-        hidden_states, cross_attn_present_key_value = ATTN_FORWARD_MAP[attn_impl](
-            self.encoder_attn,
-            hidden_states=hidden_states,
-            key_value_states=encoder_hidden_states,
-            past_key_value=cross_attn_past_key_value,
-            attention_mask=encoder_attention_mask,
-            cache_position=cache_position,
-            batch_index=batch_ids,
-        )
-        hidden_states = residual + hidden_states
-        hidden_states = self.encoder_attn_layer_norm(hidden_states)
-        present_key_value = present_key_value + cross_attn_present_key_value

-
+class BartLayerFF(nn.Module):
+    def __init__(self, decoder_layer):
+        super().__init__()
+        self.fc1 = decoder_layer.fc1
+        self.fc2 = decoder_layer.fc2
+        self.activation_fn = decoder_layer.activation_fn
+        self.layer_norm = decoder_layer.final_layer_norm
+
+    def forward(self, hidden_states):
+        # Residual Connection
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
         hidden_states = self.fc2(hidden_states)
         hidden_states = residual + hidden_states
-        hidden_states = self.
-
-        return hidden_states, present_key_value
-
-
-class _BartDecoder(BartDecoder):
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        past_key_values: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_ids: torch.Tensor,
-        attn_impl: str = "eager",
-    ):
-        # embedding
-        if hasattr(self, "embed_scale"):
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-        else:
-            inputs_embeds = self.embed_tokens(input_ids)
+        hidden_states = self.layer_norm(hidden_states)
+        return hidden_states

-        if cache_position.dim() == 0:
-            positions_idx = cache_position + self.embed_positions.offset
-            positions = self.embed_positions.weight[positions_idx]
-            hidden_states = inputs_embeds + positions
-        else:
-            hidden_all = []
-            # compiler pattern base dependency -> take + add
-            for i in range(input_ids.shape[0]):
-                # cache position [N,1]
-                positions_idx = cache_position[i]
-                # offset is set 2 in bart embedding
-                position_weight = self.embed_positions.weight[2:]
-                position = position_weight[positions_idx]
-                batch_hidden = position + inputs_embeds[i]
-                hidden_all.append(batch_hidden)
-            hidden_states = torch.stack(hidden_all, dim=0)

-
+class BartDecoderLayer(Seq2SeqDecoderLayer):
+    def __post_init__(self):
+        self.self_attn_layer_norm = self._original_mod.self_attn_layer_norm
+        self.encoder_attn = self._original_mod.encoder_attn
+        self.encoder_attn_layer_norm = self._original_mod.encoder_attn_layer_norm
+        self.ff_layer = BartLayerFF(self._original_mod)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                batch_ids=batch_ids,
-                attn_impl=attn_impl,
-            )
-            hidden_states = layer_outputs[0]
-            next_decoder_cache += (layer_outputs[1],)
-
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
-        )
-
-
-class BartDecoderWrapper(torch.nn.Module):
-    def __init__(self, model: "BartForConditionalGeneration"):
-        super().__init__()
-        self.config = model.config
-        self.decoder = model.get_decoder()
-        self.num_layers = self.config.decoder_layers
-        self.lm_head = model.lm_head
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_position: torch.Tensor,
-        self_kv_cache: torch.Tensor,
-        cross_kv_cache: torch.Tensor,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
-        if input_ids.shape[1] == 1:
-            rbln_batch_position = None
-        else:
-            rbln_batch_position = batch_position
-        # prepare past_key_values
-        kv_cache = ()
-        for i in range(0, self.num_layers * 2, 2):
-            kv_cache = kv_cache + (
-                (
-                    self_kv_cache[i],
-                    self_kv_cache[i + 1],
-                    cross_kv_cache[i],
-                    cross_kv_cache[i + 1],
-                ),
-            )
-        # decode
-        decoder_outputs = _BartDecoder.forward(
-            self.decoder,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_attention_mask=encoder_attention_mask,
-            cache_position=cache_position,
-            past_key_values=kv_cache,
-            encoder_hidden_states=torch.tensor([1]),
-            attn_impl=self.config._attn_implementation,
-            batch_ids=rbln_batch_position,
-        )
-        sequence_output = decoder_outputs[0]
-        lm_logits = self.lm_head(sequence_output)
-
-        # get self_kv_cache from ouputs
-        past_key_values = decoder_outputs[1]
-        self_kv_cache = []
-        for i in range(self.num_layers):
-            self_kv_cache.append(past_key_values[i][0])
-            self_kv_cache.append(past_key_values[i][1])
-        self_kv_cache = torch.stack(self_kv_cache, dim=0)
-
-        # return batch_position to keep it as a variable within the graph
-        return lm_logits, self_kv_cache, batch_position
-
-
-class BartEncoderWrapper(torch.nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.model = model
-        self.config = model.config
-        self.decoder = model.get_decoder()
-        self.encoder = model.get_encoder()
-        self.num_layers = self.config.encoder_layers
-        self.decoder_max_length = self.config.max_position_embeddings
-        self.encoder_max_length = self.config.max_position_embeddings
-        self.num_heads = self.config.decoder_attention_heads
-        self.d_kv = self.config.d_model // self.num_heads
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.LongTensor,
-        cross_key_value: torch.Tensor = None,
-        batch_idx: torch.Tensor = None,
-    ) -> Tuple[torch.Tensor]:
-        # 1. run encoder
-        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
-        last_hidden_states = encoder_outputs[0]
-
-        # 2. run dummy decoder to get pre-calculated cross-key_values for generation
-        dummy_past_key_value = []
-        for _ in range(self.num_layers):
-            pkv_self_attn_key = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
-            pkv_self_attn_value = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
-            pkv_cross_attn_key = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
-            pkv_cross_attn_value = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
-            layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
-            dummy_past_key_value.append(layer_pkv)
-
-        decoder_attention_mask = torch.zeros(1, self.decoder_max_length, dtype=torch.float32)
-        decoder_attention_mask[:, :1] = 1
-
-        decoder_outputs = _BartDecoder.forward(
-            self.decoder,
-            input_ids=torch.zeros((1, 1), dtype=torch.int64),
-            attention_mask=decoder_attention_mask,
-            encoder_attention_mask=attention_mask,
-            cache_position=torch.tensor(0, dtype=torch.int32),
-            encoder_hidden_states=last_hidden_states,
-            past_key_values=dummy_past_key_value,
-            batch_ids=torch.tensor(0, dtype=torch.int32),
-            attn_impl=self.config._attn_implementation,
-        )
-        first_past_kv = decoder_outputs[1]
-
-        encoder_kv = []
-        for i in range(self.model.config.decoder_layers):
-            encoder_kv.append(first_past_kv[i][2].unsqueeze(0))
-            encoder_kv.append(first_past_kv[i][3].unsqueeze(0))
-        encoder_kv = torch.cat(encoder_kv, dim=0)
-
-        cross_key_value = cross_key_value.slice_scatter(encoder_kv, dim=1, start=batch_idx, end=batch_idx + 1)
-
-        return cross_key_value
+    def pre_self_attn_layer_norm(self, hidden_states):
+        return hidden_states
+
+    def post_self_attn_layer_norm(self, hidden_states):
+        return self.self_attn_layer_norm(hidden_states)
+
+    def pre_cross_attn_layer_norm(self, hidden_states):
+        return hidden_states
+
+    def post_cross_attn_layer_norm(self, hidden_states):
+        return self.encoder_attn_layer_norm(hidden_states)
+
+
+class BartSelfAttention(Seq2SeqSelfAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.out_proj = self._original_mod.out_proj
+        self.num_heads = self._original_mod.num_heads
+        self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+        self.scaling = self.head_dim**-0.5
+        self.attn_decode = torch.ops.rbln_custom_ops.attn_decode
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states = self.q_proj(hidden_states) * self.scaling
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        return query_states, key_states, value_states
optimum/rbln/transformers/models/bart/modeling_bart.py

@@ -24,9 +24,9 @@
 import inspect
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

-from transformers import
+from transformers import BartForConditionalGeneration, PretrainedConfig, PreTrainedModel

-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ...models.seq2seq import RBLNModelForSeq2SeqLM
@@ -41,9 +41,6 @@ if TYPE_CHECKING:


 class RBLNBartModel(RBLNModel):
-    original_model_class = BartModel
-    original_config_class = BartConfig
-
     @classmethod
     def _get_rbln_config(
         cls,
@@ -82,7 +79,7 @@ class RBLNBartModel(RBLNModel):
         if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
             rbln_model_input_names = cls.rbln_model_input_names
         elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-            input_names_order = inspect.signature(cls.
+            input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
             raise ValueError(
                 "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                 f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(input_names_order)})"
@@ -96,11 +93,12 @@ class RBLNBartModel(RBLNModel):
             for model_input_name in rbln_model_input_names
         ]

-
+        enc_compile_config = RBLNCompileConfig(input_info=input_info, compiled_model_name="encoder")
+        dec_compile_config = RBLNCompileConfig(input_info=input_info, compiled_model_name="decoder")

         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
-            compile_cfgs=[
+            compile_cfgs=[enc_compile_config, dec_compile_config],
             rbln_kwargs=rbln_kwargs,
         )

@@ -111,7 +109,10 @@ class RBLNBartModel(RBLNModel):
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-
+        enc_max_seq_len = (
+            rbln_config.model_cfg["enc_max_seq_len"] if "enc_max_seq_len" in rbln_config.model_cfg else 1024
+        )
+        return BartWrapper(model, enc_max_seq_len=enc_max_seq_len)

     def __getattr__(self, __name: str) -> Any:
         def redirect(func):