optimum-rbln 0.7.3a1__py3-none-any.whl → 0.7.3a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/ops/__init__.py +4 -4
- optimum/rbln/ops/attn.py +44 -84
- optimum/rbln/ops/flash_attn.py +25 -25
- optimum/rbln/transformers/models/bart/bart_architecture.py +10 -6
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +79 -51
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +157 -34
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +7 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +7 -2
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +3 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +3 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +5 -3
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +44 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +50 -19
- optimum/rbln/transformers/models/t5/modeling_t5.py +211 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +69 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +19 -24
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/RECORD +22 -22
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py

@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from torch import nn
 from transformers.utils import logging
 
-from ....ops import
+from ....ops import (
+    register_rbln_custom_cache_update,
+    register_rbln_custom_paged_attention,
+    register_rbln_custom_paged_causal_attention,
+)
 
 
 logger = logging.get_logger(__name__)
@@ -87,7 +91,7 @@ class Seq2SeqEncoderWrapper(nn.Module):
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         cross_key_values: torch.Tensor,
-
+        b_idx: torch.Tensor,
     ) -> Tuple[torch.Tensor]:
         # 1. get encoder last_hidden_states
         encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
@@ -111,7 +115,7 @@ class Seq2SeqEncoderWrapper(nn.Module):
         # 3. update the cross_attention's past_key_value direct to the device-dram for optimization.
         batch_axis = torch.tensor(1, dtype=torch.int16)
         cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(
-            cross_key_values, cross_kv,
+            cross_key_values, cross_kv, b_idx[0], batch_axis
         )
 
         return cross_key_values
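The encoder now hands the cache-update op an explicit batch index and batch axis. As a rough host-side reference of what such an update does (a minimal sketch only; the real `torch.ops.rbln_custom_ops.rbln_cache_update` is a compiled RBLN kernel, and `cache_update_reference` below is a hypothetical stand-in):

```python
import torch

def cache_update_reference(cache: torch.Tensor, new: torch.Tensor, idx: torch.Tensor, axis: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference: copy `new` into `cache` starting at `idx` along `axis`.
    out = cache.clone()
    out.narrow(int(axis), int(idx), new.shape[int(axis)]).copy_(new)
    return out

# Updating batch entry 0 of a (2 * n_layer, batch, n_head, enc_len, d_kv) cross-KV buffer.
cross_key_values = torch.zeros(4, 2, 8, 16, 64)
cross_kv = torch.randn(4, 1, 8, 16, 64)
b_idx = torch.tensor([0], dtype=torch.int16)
batch_axis = torch.tensor(1, dtype=torch.int16)
updated = cache_update_reference(cross_key_values, cross_kv, b_idx[0], batch_axis)
```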
@@ -131,9 +135,10 @@ class Seq2SeqDecoderWrapper(nn.Module):
         **kwargs: Additional arguments for decoder configuration.
     """
 
-    def __init__(self, model: nn.Module, **kwargs):
+    def __init__(self, model: nn.Module, use_attention_mask: bool = True, **kwargs):
         super().__init__()
         self.config = model.config
+        self.use_attention_mask = use_attention_mask
         self.__post_init__(model, **kwargs)
 
     def __post_init__(self, model: nn.Module, **kwargs):
@@ -143,7 +148,11 @@ class Seq2SeqDecoderWrapper(nn.Module):
         It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
         by subclasses to modify or add custom attributes as necessary.
         """
-
+        if self.use_attention_mask:
+            register_rbln_custom_paged_attention()
+        else:
+            register_rbln_custom_paged_causal_attention()
+
         self.num_layers = self.config.decoder_layers
         self.conditional_generation = self.convert_to_rbln_conditional_generation(model)
 
@@ -160,13 +169,23 @@ class Seq2SeqDecoderWrapper(nn.Module):
 
     def forward(
         self,
-
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        cross_kv_cache: torch.Tensor,
-        *self_kv_cache: torch.Tensor,
+        *args,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+        if self.use_attention_mask:
+            (
+                input_ids,
+                attention_mask,
+                encoder_attention_mask,
+                cache_position,
+                block_tables,
+                cross_kv_cache,
+                *self_kv_cache,
+            ) = args
+
+        else:
+            attention_mask = None
+            (input_ids, encoder_attention_mask, cache_position, block_tables, cross_kv_cache, *self_kv_cache) = args
+
         self_past_key_values = ()
         cross_past_key_values = ()
         for i in range(0, self.num_layers * 2, 2):
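Because the compiled decoder receives one flat positional argument list, the unpacking above fixes the expected input order for each mode. A small illustration of that contract (placeholder tensors with made-up shapes, not taken from a real compiled model):

```python
import torch

# Placeholder inputs; shapes are illustrative only.
input_ids = torch.zeros(1, 1, dtype=torch.int64)
attention_mask = torch.zeros(1, 16)
encoder_attention_mask = torch.zeros(1, 32)
cache_position = torch.zeros(1, 1, dtype=torch.int32)
block_tables = torch.zeros(1, dtype=torch.int16)
cross_kv_cache = torch.zeros(2, 1, 8, 32, 64)
self_kv_cache = [torch.zeros(1, 8, 16, 64), torch.zeros(1, 8, 16, 64)]

# use_attention_mask=True: the mask is part of the flattened argument list.
args_with_mask = (input_ids, attention_mask, encoder_attention_mask,
                  cache_position, block_tables, cross_kv_cache, *self_kv_cache)

# use_attention_mask=False: the mask is omitted and the wrapper sets it to None internally.
args_without_mask = (input_ids, encoder_attention_mask,
                     cache_position, block_tables, cross_kv_cache, *self_kv_cache)
```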
@@ -181,6 +200,7 @@ class Seq2SeqDecoderWrapper(nn.Module):
             self_past_key_values=self_past_key_values,
             cross_past_key_values=cross_past_key_values,
             cache_position=cache_position,
+            block_tables=block_tables,
         )
 
         outputs = (lm_logits,) + self_present_key_values
@@ -228,6 +248,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):
         self_past_key_values,
         cross_past_key_values,
         cache_position,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         hidden_states, self_present_key_values = self.decoder(
             input_ids=input_ids,
@@ -236,6 +257,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):
             self_past_key_values=self_past_key_values,
             cross_past_key_values=cross_past_key_values,
             cache_position=cache_position,
+            block_tables=block_tables,
         )
 
         if self.has_rescaling and self.config.tie_word_embeddings:
@@ -292,6 +314,7 @@ class Seq2SeqDecoder(torch.nn.Module):
         self_past_key_values: torch.Tensor,
         cross_past_key_values: torch.Tensor,
         cache_position: torch.Tensor,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         # embedding
         hidden_states = self.get_embedding()(input_ids)
@@ -314,6 +337,7 @@ class Seq2SeqDecoder(torch.nn.Module):
                 self_past_key_value=self_past_key_value,
                 cross_past_key_value=cross_past_key_value,
                 cache_position=cache_position,
+                block_tables=block_tables,
             )
             self_present_key_values += self_present_key_value
 
@@ -373,6 +397,7 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
         self_past_key_value: Tuple[torch.Tensor],
         cross_past_key_value: Tuple[torch.Tensor],
         cache_position: torch.Tensor,
+        block_tables: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         dummy_encoder_hidden_states = torch.zeros(1, encoder_attention_mask.shape[-1])
 
@@ -384,6 +409,7 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
             past_key_value=self_past_key_value,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            block_tables=block_tables,
         )
         hidden_states = residual + hidden_states
         hidden_states = self.post_self_attn_layer_norm(hidden_states)
@@ -407,10 +433,10 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
 
 
 class Seq2SeqSelfAttention(nn.Module):
-    def __init__(self, attn):
+    def __init__(self, attn, **kwargs):
         super().__init__()
         self._original_mod = attn
-        self.__post_init__()
+        self.__post_init__(**kwargs)
 
     def __post_init__(self, **kwargs):
         """
@@ -442,6 +468,7 @@ class Seq2SeqSelfAttention(nn.Module):
         past_key_value: Tuple[torch.Tensor],
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        block_tables: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         bsz, tgt_len, _ = hidden_states.size()
 
@@ -450,18 +477,22 @@ class Seq2SeqSelfAttention(nn.Module):
         key_states = self._shape(key_states, -1, bsz)
         value_states = self._shape(value_states, -1, bsz)
 
-
+        block_size = past_key_value[0].shape[-2]
+        args = [
             query_states,
             key_states,
             value_states,
-            attention_mask.unsqueeze(
-                2
-            ),  # Unsqueeze group axis since CustomKernel expects it for group query attention
             past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
             past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
             cache_position,
             torch.tensor(1.0, dtype=torch.float32),  # scale
-
+            block_tables,
+            block_size,
+        ]
+        if attention_mask is not None:
+            args.insert(3, attention_mask.unsqueeze(2))
+
+        attn_output, key_states, value_states = self.attn_decode(*args)
 
         attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
         attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
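Here `block_size` is read from the cache layout (`past_key_value[0].shape[-2]`) and `block_tables` tells the paged-attention kernel which cache block each sequence owns, while the attention mask becomes optional (the causal variant needs none). A toy, host-side illustration of the block-table indexing idea (not how the device kernel is implemented):

```python
import torch

num_blocks, block_size, n_head, d_kv = 4, 16, 8, 64
# KV cache organized as blocks rather than one contiguous per-batch region.
kv_blocks = torch.randn(num_blocks, n_head, block_size, d_kv)

# One block id per sequence (the seq2seq decoder above compiles a [1]-shaped block table).
block_tables = torch.tensor([2], dtype=torch.int16)

# Resolving the cache region for sequence 0 via the table:
seq0_cache = kv_blocks[block_tables[0].long()]  # (n_head, block_size, d_kv)
```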
optimum/rbln/transformers/models/t5/modeling_t5.py

@@ -13,8 +13,9 @@
 # limitations under the License.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
+import rebel
 import torch
 from transformers import (
     AutoModelForTextEncoding,
@@ -22,7 +23,7 @@ from transformers import (
     T5EncoderModel,
     T5ForConditionalGeneration,
 )
-from transformers.modeling_outputs import BaseModelOutput
+from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
 from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
 from ....modeling import RBLNModel
@@ -57,6 +58,63 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )
 
 
+class RBLNRuntimeEncoder(RBLNPytorchRuntime):
+    mandatory_members = ["main_input_name"]
+
+    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+        _ = super().forward(*args, **kwargs)
+        return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
+
+
+class RBLNRuntimeDecoder(RBLNPytorchRuntime):
+    mandatory_members = ["main_input_name"]
+
+    def __init__(
+        self,
+        runtime: rebel.Runtime,
+        batch_size: int,
+        dec_max_seq_len: int,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(runtime, **kwargs)
+        self.batch_size = batch_size
+        self.dec_max_seq_len = dec_max_seq_len
+
+    def forward(
+        self,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        batch_size = decoder_input_ids.shape[0]
+        if batch_size != self.batch_size:
+            raise RuntimeError(
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+            )
+
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+        for b_idx in range(self.batch_size):
+            decoding_step = cache_position[b_idx].item()
+            if not (0 <= decoding_step < self.dec_max_seq_len):
+                raise ValueError(
+                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                )
+            decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
+        lm_logits = super().forward(
+            decoder_input_ids,
+            decoder_attention_mask,
+            attention_mask,
+            cache_position,
+        )
+
+        return Seq2SeqLMOutput(logits=lm_logits)
+
+
 class T5EncoderWrapper(torch.nn.Module):
     def __init__(self, model: "T5EncoderModel") -> None:
         super().__init__()
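The decoder runtime rebuilds the decoder attention mask on the host from `cache_position` before every step. A minimal standalone sketch of what that loop produces (illustrative shapes):

```python
import torch

batch_size, dec_max_seq_len = 2, 8
cache_position = torch.tensor([[3], [5]], dtype=torch.int32)  # current decoding step per sequence
decoder_attention_mask = torch.zeros(batch_size, dec_max_seq_len)

# Same effect as the loop in RBLNRuntimeDecoder.forward: each row may attend to
# every position up to and including its current decoding step.
for b_idx in range(batch_size):
    step = cache_position[b_idx].item()
    decoder_attention_mask[b_idx, : step + 1] = 1

# decoder_attention_mask:
# tensor([[1., 1., 1., 1., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 1., 1., 0., 0.]])
```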
@@ -189,6 +247,21 @@ class RBLNT5EncoderModel(RBLNModel):
 
 
 class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    def __post_init__(self, **kwargs):
+        batch_size = self.rbln_config.model_cfg["batch_size"]
+        dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+
+        self.encoder = RBLNRuntimeEncoder(
+            runtime=self.model[0],
+            main_input_name="input_ids",
+        )
+        self.decoder = RBLNRuntimeDecoder(
+            runtime=self.model[1],
+            main_input_name="input_ids",
+            batch_size=batch_size,
+            dec_max_seq_len=dec_max_seq_len,
+        )
+
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
         enc_max_seq_len = rbln_config.model_cfg["enc_max_seq_len"]
@@ -206,3 +279,139 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
             return redirect(val)
 
         return val
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model_config: "PretrainedConfig",
+        rbln_kwargs: Dict[str, Any] = {},
+    ) -> RBLNConfig:
+        rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
+        rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
+
+        n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
+        n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
+        d_kv = (
+            model_config.d_kv
+            if hasattr(model_config, "d_kv")
+            else model_config.d_model // model_config.encoder_attention_heads
+        )
+
+        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+            model_config, "max_position_embeddings", None
+        )
+
+        rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
+        if rbln_pad_token_id is None:
+            rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
+            if rbln_pad_token_id is None:
+                rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
+                if rbln_pad_token_id is None:
+                    rbln_pad_token_id = -1
+
+        if rbln_enc_max_seq_len is None:
+            rbln_enc_max_seq_len = max_position_embeddings
+            if rbln_enc_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_enc_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_enc_max_seq_len is None:
+                    raise ValueError("`rbln_enc_max_seq_len` should be specified!")
+        if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_dec_max_seq_len is None:
+            rbln_dec_max_seq_len = max_position_embeddings
+            if rbln_dec_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_dec_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_dec_max_seq_len is None:
+                    raise ValueError("`rbln_dec_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        # model input info
+        enc_input_info = [
+            ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
+            ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
+            (
+                "cross_key_value_states",
+                [
+                    n_layer * 2,
+                    rbln_batch_size,
+                    n_head,
+                    rbln_enc_max_seq_len,
+                    d_kv,
+                ],
+                "float32",
+            ),
+            ("block_tables", [1], "int16"),
+        ]
+
+        dec_input_info = [
+            ("input_ids", [rbln_batch_size, 1], "int64"),
+            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
+            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
+            (
+                "cache_position",
+                [rbln_batch_size, 1],
+                "int32",
+            ),
+        ]
+        dec_input_info.extend(
+            [
+                (
+                    "cross_key_value_states",
+                    [
+                        n_layer * 2,
+                        rbln_batch_size,
+                        n_head,
+                        rbln_enc_max_seq_len,
+                        d_kv,
+                    ],
+                    "float32",
+                )
+            ]
+        )
+        dec_input_info.extend(
+            [
+                (
+                    f"self_key_value_states_{i}",
+                    [
+                        rbln_batch_size,
+                        n_head,
+                        rbln_dec_max_seq_len,
+                        d_kv,
+                    ],
+                    "float32",
+                )
+                for i in range(n_layer * 2)
+            ]
+        )
+
+        enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[enc_compile_config, dec_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "enc_max_seq_len": rbln_enc_max_seq_len,
+                "dec_max_seq_len": rbln_dec_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "pad_token_id": rbln_pad_token_id,
+            }
+        )
+
+        return rbln_config
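The new `_get_rbln_config` consumes the usual `rbln_*` export kwargs. If the standard optimum-rbln export flow applies, compilation would look roughly like this (the checkpoint name and sequence lengths are example values, not taken from this diff):

```python
from optimum.rbln import RBLNT5ForConditionalGeneration

# Hedged usage sketch: the rbln_* kwargs feed rbln_kwargs in _get_rbln_config;
# enc/dec max sequence lengths fall back to the model config or tokenizer when omitted.
model = RBLNT5ForConditionalGeneration.from_pretrained(
    "google-t5/t5-small",      # example checkpoint
    export=True,               # compile for the RBLN NPU
    rbln_batch_size=1,
    rbln_enc_max_seq_len=512,
    rbln_dec_max_seq_len=512,
)
model.save_pretrained("t5-small-rbln")
```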
optimum/rbln/transformers/models/t5/t5_architecture.py

@@ -18,7 +18,7 @@ import torch
 from torch import nn
 from transformers.utils import logging
 
-from ....ops import
+from ....ops import register_rbln_custom_add_softmax_attention
 from ..seq2seq.seq2seq_architecture import (
     Seq2SeqDecoder,
     Seq2SeqDecoderLayer,
@@ -55,7 +55,7 @@ class T5EncoderWrapper(Seq2SeqEncoderWrapper):
 
 class T5DecoderWrapper(Seq2SeqDecoderWrapper):
     def __post_init__(self, model, dec_max_seq_len: int = None):
-
+        register_rbln_custom_add_softmax_attention()
         self.num_layers = self.config.num_layers
         self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)
 
@@ -71,6 +71,36 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
 
         return new_model
 
+    def forward(
+        self,
+        input_ids,
+        attention_mask,
+        encoder_attention_mask,
+        cache_position,
+        cross_kv_cache,
+        *self_kv_cache,
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+        self_past_key_values = ()
+        cross_past_key_values = ()
+
+        for i in range(0, self.num_layers * 2, 2):
+            self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
+            cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
+
+        # decode
+        lm_logits, self_present_key_values = self.conditional_generation(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_attention_mask=encoder_attention_mask,
+            self_past_key_values=self_past_key_values,
+            cross_past_key_values=cross_past_key_values,
+            cache_position=cache_position,
+        )
+
+        outputs = (lm_logits,) + self_present_key_values
+
+        return outputs
+
 
 class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
     has_rescaling = True
@@ -134,7 +164,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         self.out_proj = self._original_mod.o
         self.num_heads = self._original_mod.n_heads
         self.head_dim = self._original_mod.key_value_proj_dim
-        self.attn_decode = torch.ops.rbln_custom_ops.
+        self.attn_decode = torch.ops.rbln_custom_ops.add_softmax_attn_decode
 
     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states)
@@ -142,6 +172,42 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         value_states = self.v_proj(hidden_states)
         return query_states, key_states, value_states
 
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Tuple[torch.Tensor],
+        attention_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
+        bsz, tgt_len, _ = hidden_states.size()
+
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+        query_states = self._shape(query_states, tgt_len, bsz)
+        key_states = self._shape(key_states, -1, bsz)
+        value_states = self._shape(value_states, -1, bsz)
+
+        attn_output, key_states, value_states = self.attn_decode(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask.unsqueeze(
+                2
+            ),  # Unsqueeze group axis since CustomKernel expects it for group query attention
+            past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            cache_position,
+            torch.tensor(1.0, dtype=torch.float32),  # scale
+        )
+
+        attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
+
+        attn_output = self.out_proj(attn_output)
+        present_key_value = (key_states, value_states)
+
+        return attn_output, present_key_value
+
 
 class T5CrossAttention(nn.Module):
     def __init__(self, attn):
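T5's self-attention keeps the add-softmax decode kernel rather than the paged path, and both it and the Whisper code below hand the kernel tensors with an explicit group axis of size 1. A small shape check of that layout (the mask shape here is illustrative, not taken from the diff):

```python
import torch

bsz, n_heads, head_dim, max_len = 1, 8, 64, 16

# KV caches are reshaped to expose a group axis (size 1: plain multi-head attention,
# not grouped-query attention) before being handed to the custom kernel.
k_cache = torch.zeros(bsz, n_heads, max_len, head_dim)
print(k_cache.view(bsz, n_heads, 1, -1, head_dim).shape)  # torch.Size([1, 8, 1, 16, 64])

mask = torch.zeros(bsz, 1, max_len)        # illustrative mask layout
print(mask.unsqueeze(2).shape)             # group axis at dim 2 -> torch.Size([1, 1, 1, 16])
```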
optimum/rbln/transformers/models/whisper/whisper_architecture.py

@@ -25,7 +25,7 @@ from transformers.modeling_outputs import (
 )
 from transformers.utils import logging
 
-from ....ops import register_rbln_custom_cache_update
+from ....ops import register_rbln_custom_add_softmax_attention, register_rbln_custom_cache_update
 
 
 logger = logging.get_logger(__name__)
@@ -34,6 +34,7 @@ logger = logging.get_logger(__name__)
 class WhisperWrapper:
     def __init__(self, model, rbln_token_timestamps):
         register_rbln_custom_cache_update()
+        register_rbln_custom_add_softmax_attention()
         self.encoder = WhisperEncoderWrapper(model)
         self.decoder = WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps)
 
@@ -213,7 +214,7 @@ class WhisperDecoderLayer(nn.Module):
         # Self Attention Block
         residual = hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states,
+        hidden_states, self_present_key_value = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=self_past_key_value,
             attention_mask=attention_mask,
@@ -224,7 +225,7 @@ class WhisperDecoderLayer(nn.Module):
         # Cross-Attention Block
         residual = hidden_states
         hidden_states = self.encoder_attn_layer_norm(hidden_states)
-        hidden_states, cross_attn_weights
+        hidden_states, cross_attn_weights = self.encoder_attn(
             hidden_states=hidden_states,
             past_key_value=cross_past_key_value,
         )
@@ -258,19 +259,8 @@ class WhisperAttention(nn.Module):
 
 
 class WhisperSelfAttention(WhisperAttention):
-    def rbln_cache_update(
-        self,
-        past_key_value: torch.Tensor,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        cache_position: torch.Tensor,
-    ):
-        s_idx = torch.tensor(cache_position, dtype=torch.int16)
-        axis = torch.tensor(2, dtype=torch.int16)
-
-        key_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[0], key_states, s_idx, axis)
-        value_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[1], value_states, s_idx, axis)
-        return key_states, value_states
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+        return tensor.view(bsz, seq_len, 1, self.num_heads, self.head_dim).transpose(1, 3)
 
     def forward(
         self,
@@ -285,22 +275,27 @@ class WhisperSelfAttention(WhisperAttention):
 
         key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
         value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-        key_states, value_states = self.rbln_cache_update(past_key_value, key_states, value_states, cache_position)
 
-
-
-
+        attn_output, key_states, value_states = torch.ops.rbln_custom_ops.add_softmax_attn_decode(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask.unsqueeze(2),
+            past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            cache_position.expand(bsz, 1),
+            torch.tensor(1.0, dtype=torch.float32),  # scale
+        )
 
-        attn_output = torch.matmul(attn_weights, value_states)
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
         attn_output = self.out_proj(attn_output)
 
-        return attn_output,
+        return attn_output, (key_states, value_states)
 
 
-class WhisperCrossAttention(WhisperSelfAttention):
+class WhisperCrossAttention(WhisperAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
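The fused kernel replaces the removed per-step `rbln_cache_update` call plus the explicit matmul path in Whisper's self-attention. A hypothetical host-side reference of what a fused cache-update-plus-masked-attention step computes (a sketch for intuition only; the real op's semantics are defined by the RBLN compiler):

```python
import torch

def fused_attn_reference(q, k_new, v_new, mask, k_cache, v_cache, pos):
    # Hypothetical reference: write the new K/V into the cache at `pos`,
    # then run additive-mask softmax attention over the whole cache.
    k_cache, v_cache = k_cache.clone(), v_cache.clone()
    k_cache[..., pos : pos + k_new.shape[-2], :] = k_new
    v_cache[..., pos : pos + v_new.shape[-2], :] = v_new
    scores = q @ k_cache.transpose(-1, -2) + mask
    attn = torch.softmax(scores, dim=-1) @ v_cache
    return attn, k_cache, v_cache

q = torch.randn(1, 8, 1, 1, 64)           # (batch, heads, group, tgt_len, head_dim)
k_new = torch.randn(1, 8, 1, 1, 64)
v_new = torch.randn(1, 8, 1, 1, 64)
k_cache = torch.zeros(1, 8, 1, 16, 64)
v_cache = torch.zeros(1, 8, 1, 16, 64)
mask = torch.zeros(1, 1, 1, 1, 16)        # additive mask, 0 = attend
out, k_cache, v_cache = fused_attn_reference(q, k_new, v_new, mask, k_cache, v_cache, pos=3)
```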
@@ -322,4 +317,4 @@ class WhisperCrossAttention(WhisperSelfAttention):
         attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
         attn_output = self.out_proj(attn_output)
 
-        return attn_output, attn_weights
+        return attn_output, attn_weights
{optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.7.3a1
+Version: 0.7.3a3
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai