optimum-rbln 0.1.15__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +26 -33
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/__init__.py +4 -0
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
- optimum/rbln/diffusers/models/__init__.py +2 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
- optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
- optimum/rbln/diffusers/models/controlnet.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
- optimum/rbln/diffusers/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
- optimum/rbln/modeling.py +13 -347
- optimum/rbln/modeling_base.py +24 -4
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -0
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
- optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/utils/rbln_quantization.py +0 -1
- optimum/rbln/utils/decorator_utils.py +51 -15
- optimum/rbln/utils/import_utils.py +7 -0
- optimum/rbln/utils/logging.py +37 -0
- optimum/rbln/utils/model_utils.py +0 -1
- optimum/rbln/utils/runtime_utils.py +9 -3
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +23 -0
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/METADATA +37 -26
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/RECORD +76 -72
- optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,498 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from typing import Tuple
|
25
|
+
|
26
|
+
import torch
|
27
|
+
from torch import nn
|
28
|
+
from transformers.utils import logging
|
29
|
+
|
30
|
+
from ....ops import register_rbln_custom_attention, register_rbln_custom_cache_update
|
31
|
+
|
32
|
+
|
33
|
+
logger = logging.get_logger(__name__)
|
34
|
+
|
35
|
+
|
36
|
+
class Seq2SeqWrapper:
    """Bundle the RBLN encoder and decoder wrappers built from one Seq2Seq model.

    The same ``model`` object is handed to both halves: ``self.encoder`` is a
    ``Seq2SeqEncoderWrapper`` and ``self.decoder`` is a ``Seq2SeqDecoderWrapper``.

    Args:
        model (nn.Module): The Seq2Seq model to wrap.
        enc_max_seq_len (int): Maximum encoder sequence length, forwarded to the encoder wrapper.
        **kwargs: Extra keyword arguments forwarded unchanged to the decoder wrapper.
    """

    def __init__(self, model: nn.Module, enc_max_seq_len: int, **kwargs):
        # Build the encoder half first, then the decoder half, from the same model.
        encoder_wrapper = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
        decoder_wrapper = Seq2SeqDecoderWrapper(model, **kwargs)
        self.encoder = encoder_wrapper
        self.decoder = decoder_wrapper
|
51
|
+
|
52
|
+
|
53
|
+
class Seq2SeqEncoderWrapper(nn.Module):
    """Encoder-side wrapper of a Seq2Seq model for RBLN compilation.

    Besides running the encoder, ``forward`` pre-computes the cross-attention key and
    value projections of every decoder layer from the encoder output and writes them
    into a caller-provided cache tensor through the ``rbln_cache_update`` custom op.

    Args:
        model (nn.Module): The Seq2Seq model containing the encoder.
        enc_max_seq_len (int): Maximum encoder sequence length; used as the sequence
            dimension when reshaping the projected keys/values.
    """

    def __init__(self, model: nn.Module, enc_max_seq_len: int):
        super().__init__()
        # Presumably makes `torch.ops.rbln_custom_ops.rbln_cache_update` (used in
        # forward) available; the helper is defined outside this file.
        register_rbln_custom_cache_update()
        self.config = model.config
        self.encoder = model.get_encoder()
        self.encoder_max_length = enc_max_seq_len
        self.__post_init__(model)

    def __post_init__(self, model: nn.Module):
        """Collect the attributes `forward` needs from the model and its config.

        Reads BART-style names (``decoder_layers``, ``decoder_attention_heads``,
        ``d_model``, ``get_decoder().layers``); subclasses can override this hook
        for architectures that name these differently.
        """
        # `n_layer` must be set before extracting projections, which iterates over it.
        self.n_layer = getattr(self.config, "decoder_layers", None)
        decoder_layers = model.get_decoder().layers
        key_projects, value_projects = self._extract_cross_kv_projects(decoder_layers)
        self.cross_k_projects = key_projects
        self.cross_v_projects = value_projects
        self.num_heads = self.config.decoder_attention_heads
        self.d_kv = self.config.d_model // self.num_heads

    def _extract_cross_kv_projects(self, decoder_layers: nn.Module):
        """Return (k_proj list, v_proj list) taken from each layer's ``encoder_attn``.

        Only the first ``self.n_layer`` entries of ``decoder_layers`` are used.
        """
        layer_indices = range(self.n_layer)
        key_projects = nn.ModuleList(decoder_layers[idx].encoder_attn.k_proj for idx in layer_indices)
        value_projects = nn.ModuleList(decoder_layers[idx].encoder_attn.v_proj for idx in layer_indices)
        return key_projects, value_projects

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        cross_key_values: torch.Tensor,
        batch_position: torch.Tensor,
    ) -> Tuple[torch.Tensor]:
        """Encode the input and store per-layer cross-attention K/V into the cache.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Mask passed to the encoder.
            cross_key_values: Cache tensor that receives the computed keys/values.
            batch_position: Forwarded to ``rbln_cache_update`` together with a batch
                axis of 1 (presumably selects the batch slot to write — TODO confirm).

        Returns:
            Whatever ``rbln_cache_update`` returns for the updated cache.
        """
        # 1. Run the encoder; element 0 of its output is used as the last hidden state.
        last_hidden_states = self.encoder(input_ids=input_ids, attention_mask=attention_mask)[0]

        # 2. Project the encoder output with every layer's cross-attention k/v weights.
        #    The view uses a fixed leading dimension of 1 and `encoder_max_length` as the
        #    sequence dimension; transpose(1, 2) moves heads in front of the sequence.
        head_layout = (1, self.encoder_max_length, self.num_heads, self.d_kv)
        per_layer_kv = []
        for key_proj, value_proj in zip(self.cross_k_projects, self.cross_v_projects):
            layer_key = key_proj(last_hidden_states).view(*head_layout).transpose(1, 2)
            layer_value = value_proj(last_hidden_states).view(*head_layout).transpose(1, 2)
            # Interleaved order: [k_0, v_0, k_1, v_1, ...].
            per_layer_kv.append(layer_key)
            per_layer_kv.append(layer_value)

        stacked_kv = torch.stack(per_layer_kv, dim=0)

        # 3. Write the stacked keys/values into the provided cache via the custom op.
        batch_axis = torch.tensor(1, dtype=torch.int16)
        return torch.ops.rbln_custom_ops.rbln_cache_update(
            cross_key_values, stacked_kv, batch_position, batch_axis
        )
|
127
|
+
|
128
|
+
|
129
|
+
class Seq2SeqDecoderWrapper(nn.Module):
    """Decoder-side wrapper of a Seq2Seq model for RBLN compilation.

    On construction the decoder layers of ``model`` are re-wrapped with
    ``Seq2SeqSelfAttention`` / ``Seq2SeqDecoderLayer`` / ``Seq2SeqDecoder`` and combined
    with the model's LM head in a ``Seq2SeqForConditionalGeneration``. ``forward``
    regroups the flat cache inputs into per-layer (key, value) pairs before decoding.

    Args:
        model (nn.Module): The Seq2Seq model containing the decoder.
        **kwargs: Additional arguments forwarded to ``__post_init__``.
    """

    def __init__(self, model: nn.Module, **kwargs):
        super().__init__()
        self.config = model.config
        self.__post_init__(model, **kwargs)

    def __post_init__(self, model: nn.Module, **kwargs):
        """Collect decoder-related attributes and build the RBLN decoder stack.

        Reads the BART-style ``decoder_layers`` config field; subclasses can override
        this hook for architectures that differ.
        """
        # Presumably registers the custom attention op used by the wrapped layers;
        # the helper is defined outside this file.
        register_rbln_custom_attention()
        self.num_layers = self.config.decoder_layers
        self.conditional_generation = self.convert_to_rbln_conditional_generation(model)

    def convert_to_rbln_conditional_generation(self, model: nn.Module):
        """Wrap every decoder layer and return the assembled conditional-generation module."""
        original_decoder = model.get_decoder()
        wrapped_layers = []
        for original_layer in original_decoder.layers:
            wrapped_attn = Seq2SeqSelfAttention(original_layer.self_attn)
            wrapped_layers.append(Seq2SeqDecoderLayer(original_layer, wrapped_attn))

        wrapped_decoder = Seq2SeqDecoder(model.get_decoder(), wrapped_layers)
        return Seq2SeqForConditionalGeneration(model, wrapped_decoder)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        cache_position: torch.Tensor,
        cross_kv_cache: torch.Tensor,
        *self_kv_cache: torch.Tensor,
    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
        """Run one decoding pass.

        ``cross_kv_cache`` and ``self_kv_cache`` are both indexed as
        ``[k_0, v_0, k_1, v_1, ...]``; entries ``2 * i`` and ``2 * i + 1`` form the
        (key, value) pair of layer ``i``.

        Returns:
            A flat tuple: the LM logits followed by the self-attention key/values
            returned by the wrapped model.
        """
        self_pairs = []
        cross_pairs = []
        for layer_idx in range(self.num_layers):
            key_slot = 2 * layer_idx
            value_slot = key_slot + 1
            self_pairs.append((self_kv_cache[key_slot], self_kv_cache[value_slot]))
            cross_pairs.append((cross_kv_cache[key_slot], cross_kv_cache[value_slot]))

        # decode
        lm_logits, self_present_key_values = self.conditional_generation(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            self_past_key_values=tuple(self_pairs),
            cross_past_key_values=tuple(cross_pairs),
            cache_position=cache_position,
        )

        return (lm_logits,) + self_present_key_values
|
198
|
+
|
199
|
+
|
200
|
+
class Seq2SeqForConditionalGeneration(nn.Module):
    """Decoder plus LM head of a Seq2Seq model, arranged for RBLN compilation.

    ``forward`` runs the wrapped decoder, optionally rescales its hidden states, and
    projects them to vocabulary logits with the original model's ``lm_head``.

    Attributes:
        has_rescaling (bool): When True (and ``config.tie_word_embeddings`` is truthy),
            hidden states are multiplied by ``self.scaling`` before the LM head.
            ``self.scaling`` is not defined in this class; a subclass enabling
            rescaling has to provide it (e.g. in ``__post_init__``).
        config: Configuration taken from the original Seq2Seq model.
        lm_head: The original model's language-modeling head.
        decoder (nn.Module): The wrapped decoder model.
    """

    has_rescaling = False

    def __init__(self, model, decoder_model):
        super().__init__()
        self.config = model.config
        self.lm_head = model.lm_head
        self.decoder = decoder_model
        self.__post_init__()

    def __post_init__(self):
        """Hook for subclasses to adjust attributes after construction; no-op here."""

    def forward(
        self,
        input_ids,
        attention_mask,
        encoder_attention_mask,
        self_past_key_values,
        cross_past_key_values,
        cache_position,
    ):
        """Return ``(lm_logits, self_present_key_values)`` for the given decoder inputs."""
        decoder_result = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            self_past_key_values=self_past_key_values,
            cross_past_key_values=cross_past_key_values,
            cache_position=cache_position,
        )
        hidden_states, self_present_key_values = decoder_result

        # Rescale only when the subclass opted in AND the embeddings are tied.
        needs_rescale = self.has_rescaling and self.config.tie_word_embeddings
        if needs_rescale:
            hidden_states = hidden_states * self.scaling

        return self.lm_head(hidden_states), self_present_key_values
|
256
|
+
|
257
|
+
|
258
|
+
class Seq2SeqDecoder(torch.nn.Module):
    """Decoder stack of a Seq2Seq model, restructured for RBLN compilation.

    This base class wires embedding, mask preparation, optional position embedding,
    the layer loop and an optional final layer norm together. ``prepare_attn_mask``
    and ``apply_position_embedding`` raise ``NotImplementedError`` here and must be
    supplied by a subclass.

    Args:
        model: Original decoder module; must expose ``embed_tokens`` and may expose
            ``final_layer_norm``.
        layers (List[Seq2SeqDecoderLayer]): Wrapped decoder layers to iterate over.
        **kwargs: Forwarded to ``__post_init__``.
    """

    # Subclasses set this to False to skip `apply_position_embedding` in forward.
    has_pos_emb = True

    def __init__(self, model, layers, **kwargs):
        super().__init__()
        self._original_mod = model
        self.layers = nn.ModuleList(layers)
        self.embed_tokens = model.embed_tokens
        # Not every architecture has a final norm; None disables it in forward.
        self.final_layer_norm = getattr(model, "final_layer_norm", None)
        self.__post_init__(**kwargs)

    def __post_init__(self, **kwargs):
        """Hook for subclasses to adjust attributes after construction; no-op here."""
        pass

    def get_embedding(self):
        """Return the token-embedding module used by `forward`."""
        return self.embed_tokens

    def prepare_attn_mask(self, *args, **kwargs):
        """Subclass hook: must return ``(attention_mask, encoder_attention_mask)``."""
        raise NotImplementedError(
            "The 'prepare_attn_mask' method is not implemented. Please define this method in a subclass."
        )

    def apply_position_embedding(self, *args, **kwargs):
        """Subclass hook: must return hidden states with position information applied."""
        raise NotImplementedError(
            "The 'apply_position_embedding' method is not implemented. Please define this method in a subclass."
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        self_past_key_values: torch.Tensor,
        cross_past_key_values: torch.Tensor,
        cache_position: torch.Tensor,
    ):
        """Run all decoder layers.

        ``self_past_key_values`` and ``cross_past_key_values`` are iterated in
        lock-step with ``self.layers`` (one entry per layer).

        Returns:
            ``(hidden_states, self_present_key_values)`` where the second item is the
            concatenation of the tuples returned by each layer, in layer order.
        """
        # embedding
        token_embedder = self.get_embedding()
        hidden_states = token_embedder(input_ids)
        prepared_masks = self.prepare_attn_mask(
            attention_mask, encoder_attention_mask, cache_position=cache_position
        )
        attention_mask, encoder_attention_mask = prepared_masks

        if self.has_pos_emb:
            hidden_states = self.apply_position_embedding(hidden_states, cache_position)

        # iterate decoder_layer
        collected_key_values = ()
        per_layer_inputs = zip(self.layers, self_past_key_values, cross_past_key_values)
        for block, layer_self_kv, layer_cross_kv in per_layer_inputs:
            hidden_states, layer_present_kv = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                self_past_key_value=layer_self_kv,
                cross_past_key_value=layer_cross_kv,
                cache_position=cache_position,
            )
            # Flatten: each layer contributes its tuple to one running tuple.
            collected_key_values = collected_key_values + layer_present_kv

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states, collected_key_values
|
333
|
+
|
334
|
+
|
335
|
+
class Seq2SeqDecoderLayer(torch.nn.Module):
    """One decoder layer of a Seq2Seq model, restructured for RBLN compilation.

    ``forward`` chains self-attention, cross-attention and a feed-forward step. The
    four layer-norm hooks raise ``NotImplementedError`` here and must be supplied by a
    subclass. ``self.encoder_attn`` and ``self.ff_layer`` are used by ``forward`` but
    are not assigned in this class; a subclass is expected to set them (for example
    in ``__post_init__``).

    Args:
        decoder_layer: Original Huggingface decoder layer to adapt.
        self_attn (Seq2SeqSelfAttention): Wrapped self-attention module for this layer.
    """

    def __init__(self, decoder_layer, self_attn):
        super().__init__()
        self._original_mod = decoder_layer
        self.self_attn = self_attn
        self.__post_init__()

    def __post_init__(self, **kwargs):
        """Hook for subclasses to adjust attributes after construction; no-op here."""
        pass

    def pre_self_attn_layer_norm(self, hidden_states):
        """Subclass hook applied before self-attention."""
        raise NotImplementedError(
            "The 'pre_self_attn_layer_norm' method is not implemented. Please define this method in a subclass."
        )

    def post_self_attn_layer_norm(self, hidden_states):
        """Subclass hook applied after the self-attention residual add."""
        raise NotImplementedError(
            "The 'post_self_attn_layer_norm' method is not implemented. Please define this method in a subclass."
        )

    def pre_cross_attn_layer_norm(self, hidden_states):
        """Subclass hook applied before cross-attention."""
        raise NotImplementedError(
            "The 'pre_cross_attn_layer_norm' method is not implemented. Please define this method in a subclass."
        )

    def post_cross_attn_layer_norm(self, hidden_states):
        """Subclass hook applied after the cross-attention residual add."""
        raise NotImplementedError(
            "The 'post_cross_attn_layer_norm' method is not implemented. Please define this method in a subclass."
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        self_past_key_value: Tuple[torch.Tensor],
        cross_past_key_value: Tuple[torch.Tensor],
        cache_position: torch.Tensor,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        """Apply the layer and return ``(hidden_states, self_attn_key_value)``.

        Only the first element of the cross-attention output is used; the returned
        key/value comes from the self-attention module.
        """
        # Placeholder passed as `key_value_states`; cross-attention keys/values come
        # from `cross_past_key_value` (presumably the placeholder only signals
        # "this is cross-attention" to the wrapped module — TODO confirm).
        placeholder_encoder_states = torch.zeros(1, encoder_attention_mask.shape[-1])

        # Self Attention Block
        self_attn_input = self.pre_self_attn_layer_norm(hidden_states)
        self_attn_output, self_attn_past_key_value = self.self_attn(
            hidden_states=self_attn_input,
            past_key_value=self_past_key_value,
            attention_mask=attention_mask,
            cache_position=cache_position,
        )
        hidden_states = self.post_self_attn_layer_norm(hidden_states + self_attn_output)

        # Cross-Attention Block
        cross_attn_input = self.pre_cross_attn_layer_norm(hidden_states)
        cross_attn_output = self.encoder_attn(
            hidden_states=cross_attn_input,
            past_key_value=cross_past_key_value,
            attention_mask=encoder_attention_mask,
            key_value_states=placeholder_encoder_states,
        )
        hidden_states = self.post_cross_attn_layer_norm(hidden_states + cross_attn_output[0])

        # Feed-Forward Block
        return self.ff_layer(hidden_states), self_attn_past_key_value
|
416
|
+
|
417
|
+
|
418
|
+
class Seq2SeqSelfAttention(nn.Module):
    """Self-attention wrapper that decodes one batch entry at a time.

    The original attention module is kept as ``_original_mod``. ``forward`` and its
    helpers read ``self.num_heads``, ``self.head_dim``, ``self.q_proj``, ``self.k_proj``,
    ``self.v_proj``, ``self.out_proj`` and ``self.attn_decode``, none of which are
    assigned in this class; a subclass is expected to set them (presumably in
    ``__post_init__`` — TODO confirm against the architecture-specific subclasses).
    """

    def __init__(self, attn):
        super().__init__()
        self._original_mod = attn
        self.__post_init__()

    def __post_init__(self, **kwargs):
        """
        Abstract method intended to be overridden by subclasses to modify or override
        the attributes of the original model after initialization.
        """
        pass

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        """View ``tensor`` as (bsz, 1, seq_len, 1, num_heads, head_dim) and swap axes 2 and 4.

        The result has shape (bsz, 1, num_heads, 1, seq_len, head_dim).
        """
        return tensor.view(bsz, 1, seq_len, 1, self.num_heads, self.head_dim).transpose(2, 4)

    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Projects input hidden states into query, key, and value representations.

        Args:
            hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]

        Returns:
            Tuple of (query_states, key_states, value_states)
        """
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        return query_states, key_states, value_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Tuple[torch.Tensor],
        attention_mask: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        """Compute self-attention per batch entry and return output plus key/values.

        Args:
            hidden_states: 3-D input; unpacked as (bsz, tgt_len, _).
            past_key_value: Pair whose items 0 and 1 are the cached keys and values.
            attention_mask: Indexed per batch entry along its first dimension.
            cache_position: Indexed as ``cache_position[b_idx][0]`` for each entry.

        Returns:
            ``(attn_output, (key_states, value_states))`` where each item is the
            per-entry results concatenated along dim 0.
        """
        bsz, tgt_len, _ = hidden_states.size()

        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
        query_states = self._shape(query_states, tgt_len, bsz)
        # -1 lets view() infer the sequence length for keys/values.
        key_states = self._shape(key_states, -1, bsz)
        value_states = self._shape(value_states, -1, bsz)

        all_key_states = []
        all_value_states = []
        all_attn_output = []
        for b_idx in range(bsz):
            # Indexing drops the batch axis: each *_state is (1, num_heads, 1, seq, head_dim).
            query_state = query_states[b_idx]
            key_state = key_states[b_idx]
            value_state = value_states[b_idx]
            # Insert singleton axes at positions 0 and 2 of this entry's mask.
            attn_mask = attention_mask[b_idx].unsqueeze(0).unsqueeze(2)
            # NOTE(review): the cache is viewed with the full `bsz` leading dimension on
            # every iteration (not sliced by b_idx); the custom op's handling of this is
            # not visible here — confirm against `attn_decode`.
            past_key_state = past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim)
            past_value_state = past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim)

            attn_output, key_state, value_state = self.attn_decode(
                query_state,
                key_state,
                value_state,
                attn_mask,
                past_key_state,
                past_value_state,
                cache_position[b_idx][0],
                torch.tensor(1.0, dtype=torch.float32),  # scale
            )

            # Regroup to (1, seq, num_heads * head_dim) so heads are merged per position.
            attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim).transpose(1, 2)
            attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)

            # Drop the singleton axis at position 2 before collecting.
            all_key_states.append(key_state.squeeze(2))
            all_value_states.append(value_state.squeeze(2))
            all_attn_output.append(attn_output)

        # Re-assemble the per-entry results along dim 0.
        key_states = torch.cat(all_key_states, dim=0)
        value_states = torch.cat(all_value_states, dim=0)
        attn_output = torch.cat(all_attn_output, dim=0)

        attn_output = self.out_proj(attn_output)
        present_key_value = (key_states, value_states)

        return attn_output, present_key_value
|
@@ -34,9 +34,9 @@ from transformers import (
|
|
34
34
|
)
|
35
35
|
from transformers.modeling_outputs import BaseModelOutput
|
36
36
|
|
37
|
+
from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
|
37
38
|
from ....modeling import RBLNModel
|
38
39
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
39
|
-
from ....modeling_diffusers import RBLNDiffusionMixin
|
40
40
|
from ....utils.logging import get_logger
|
41
41
|
from ....utils.runtime_utils import RBLNPytorchRuntime
|
42
42
|
from ...models.seq2seq import RBLNModelForSeq2SeqLM
|
@@ -192,7 +192,10 @@ class RBLNT5EncoderModel(RBLNModel):
|
|
192
192
|
class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
|
193
193
|
@classmethod
|
194
194
|
def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
|
195
|
-
|
195
|
+
enc_max_seq_len = rbln_config.model_cfg["enc_max_seq_len"]
|
196
|
+
dec_max_seq_len = rbln_config.model_cfg["dec_max_seq_len"]
|
197
|
+
|
198
|
+
return T5Wrapper(model, enc_max_seq_len=enc_max_seq_len, dec_max_seq_len=dec_max_seq_len)
|
196
199
|
|
197
200
|
def __getattr__(self, __name: str) -> Any:
|
198
201
|
def redirect(func):
|