optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4a1__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- optimum/rbln/__init__.py +2 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/ops/__init__.py +2 -1
- optimum/rbln/ops/attn.py +9 -7
- optimum/rbln/ops/linear.py +25 -0
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/models/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +4 -3
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +20 -17
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +14 -14
- optimum/rbln/transformers/models/t5/modeling_t5.py +3 -210
- optimum/rbln/transformers/models/t5/t5_architecture.py +9 -3
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +24 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +422 -0
- optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +341 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +98 -47
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +71 -26
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4a1.dist-info}/METADATA +5 -5
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4a1.dist-info}/RECORD +21 -17
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4a1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4a1.dist-info}/licenses/LICENSE +0 -0
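
The headline change in 0.7.4a1 is the new time-series support: RBLNTimeSeriesTransformerForPrediction is registered throughout the import structure (see the __init__.py hunks below). A minimal usage sketch, assuming the usual optimum-rbln from_pretrained/export convention — the arguments are not taken from this diff and the model id is only an example:

    from optimum.rbln import RBLNTimeSeriesTransformerForPrediction

    # `export=True` and any rbln_* options follow the convention used by the other
    # RBLN model classes; the exact signature for this class is not shown in this diff.
    model = RBLNTimeSeriesTransformerForPrediction.from_pretrained(
        "huggingface/time-series-transformer-tourism-monthly",  # example model id
        export=True,
    )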
optimum/rbln/__init__.py
CHANGED
@@ -73,6 +73,7 @@ _import_structure = {
         "RBLNRobertaForMaskedLM",
         "RBLNViTForImageClassification",
         "RBLNBertForMaskedLM",
+        "RBLNTimeSeriesTransformerForPrediction",
     ],
     "diffusers": [
         "RBLNAutoencoderKL",
@@ -184,6 +185,7 @@ if TYPE_CHECKING:
         RBLNRobertaForSequenceClassification,
         RBLNT5EncoderModel,
         RBLNT5ForConditionalGeneration,
+        RBLNTimeSeriesTransformerForPrediction,
         RBLNViTForImageClassification,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
optimum/rbln/__version__.py
CHANGED
optimum/rbln/ops/__init__.py
CHANGED
@@ -13,9 +13,10 @@
 # limitations under the License.

 from .attn import (
-    register_rbln_custom_add_softmax_attention,
+    register_rbln_custom_paged_add_softmax_attention,
     register_rbln_custom_paged_attention,
     register_rbln_custom_paged_causal_attention,
 )
 from .flash_attn import register_rbln_custom_paged_flash_attention, register_rbln_custom_paged_flash_causal_attention
 from .kv_cache_update import register_rbln_custom_cache_update
+from .linear import linear
optimum/rbln/ops/attn.py
CHANGED
@@ -182,14 +182,14 @@ def register_rbln_custom_paged_causal_attention():


 @lru_cache
-def register_rbln_custom_add_softmax_attention():
+def register_rbln_custom_paged_add_softmax_attention():
     torch.library.define(
-        "rbln_custom_ops::add_softmax_attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor",
+        "rbln_custom_ops::paged_add_softmax_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )

-    @torch.library.impl("rbln_custom_ops::add_softmax_attn_decode", "cpu")
-    def add_softmax_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale):
+    @torch.library.impl("rbln_custom_ops::paged_add_softmax_attn_decode", "cpu")
+    def paged_add_softmax_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size):
         """Defines the computation pattern for fused attention with KV cache updates.

         IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -210,12 +210,14 @@ def register_rbln_custom_add_softmax_attention():
             - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
             - seq: [1] - Current sequence position
             - scale: [] - Attention scale factor
+            - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+            - block_size: [] - Number of tokens per block

         Returns:
             Tensor: attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
         """
         return q

-    @register_fake("rbln_custom_ops::add_softmax_attn_decode")
-    def add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
+    @register_fake("rbln_custom_ops::paged_add_softmax_attn_decode")
+    def paged_add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition, block_table, block_size):
         return q
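
The rename from add_softmax_attn_decode to paged_add_softmax_attn_decode also widens the schema by two arguments (block_table and block_size) for paged KV-cache management. A sketch of how the registered op can be invoked on CPU: cache, position, and block-table shapes follow the docstring above, while the q/k/v/mask shapes are assumptions; the CPU stub simply returns q, since the real kernel is generated by the RBLN compiler.

    import torch

    from optimum.rbln.ops import register_rbln_custom_paged_add_softmax_attention

    register_rbln_custom_paged_add_softmax_attention()

    bsz, n_heads, head_dim, max_seq_len, block_size = 1, 8, 64, 128, 128
    q = torch.randn(bsz, n_heads, 1, 1, head_dim)  # single decode step (assumed shape)
    k = torch.randn(bsz, n_heads, 1, 1, head_dim)
    v = torch.randn(bsz, n_heads, 1, 1, head_dim)
    mask = torch.zeros(bsz, 1, 1, 1, max_seq_len)  # assumed shape
    kcache = torch.zeros(bsz, n_heads, 1, max_seq_len, head_dim)
    vcache = torch.zeros(bsz, n_heads, 1, max_seq_len, head_dim)
    seq = torch.tensor([0], dtype=torch.int32)  # current sequence position
    scale = torch.tensor(1.0)
    block_table = torch.zeros(bsz, max_seq_len // block_size, dtype=torch.int16)

    out = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode(
        q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size
    )
    print(out.shape)  # CPU pattern stub: returns q unchanged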
optimum/rbln/ops/linear.py
ADDED
@@ -0,0 +1,25 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+
+@torch.library.custom_op("rbln_custom_ops::linear", mutates_args=())
+def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
+    output_shape = list(input.shape[:-1])
+    output_shape += [weight.shape[0]]
+    return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
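
Note that rbln_custom_ops::linear is a shape-propagation stub: it only allocates an output of the right shape with torch.empty, so it is meaningful during tracing and compilation rather than as a CPU kernel. A quick check, assuming optimum-rbln 0.7.4a1 and a PyTorch version recent enough to provide torch.library.custom_op:

    import torch

    import optimum.rbln.ops  # noqa: F401  # importing the subpackage registers rbln_custom_ops::linear

    x = torch.randn(2, 16, 128)
    w = torch.randn(64, 128)  # (out_features, in_features), as in torch.nn.functional.linear
    out = torch.ops.rbln_custom_ops.linear(x, w, None)
    print(out.shape)  # torch.Size([2, 16, 64]); values are uninitialized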
optimum/rbln/transformers/__init__.py
CHANGED
@@ -52,6 +52,7 @@ _import_structure = {
         "RBLNPhiForCausalLM",
         "RBLNT5EncoderModel",
         "RBLNT5ForConditionalGeneration",
+        "RBLNTimeSeriesTransformerForPrediction",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNMidmLMHeadModel",
         "RBLNXLMRobertaModel",
@@ -113,6 +114,7 @@ if TYPE_CHECKING:
         RBLNQwen2ForCausalLM,
         RBLNT5EncoderModel,
         RBLNT5ForConditionalGeneration,
+        RBLNTimeSeriesTransformerForPrediction,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
         RBLNXLMRobertaModel,
optimum/rbln/transformers/models/__init__.py
CHANGED
@@ -50,6 +50,7 @@ _import_structure = {
     "mistral": ["RBLNMistralForCausalLM"],
     "phi": ["RBLNPhiForCausalLM"],
     "qwen2": ["RBLNQwen2ForCausalLM"],
+    "time_series_transformers": ["RBLNTimeSeriesTransformerForPrediction"],
     "t5": ["RBLNT5EncoderModel", "RBLNT5ForConditionalGeneration"],
     "wav2vec2": ["RBLNWav2Vec2ForCTC"],
     "whisper": ["RBLNWhisperForConditionalGeneration"],
@@ -90,6 +91,7 @@ if TYPE_CHECKING:
     from .phi import RBLNPhiForCausalLM
     from .qwen2 import RBLNQwen2ForCausalLM
     from .t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
+    from .time_series_transformers import RBLNTimeSeriesTransformerForPrediction
     from .wav2vec2 import RBLNWav2Vec2ForCTC
     from .whisper import RBLNWhisperForConditionalGeneration
     from .xlm_roberta import RBLNXLMRobertaModel
optimum/rbln/transformers/models/bart/modeling_bart.py
CHANGED
@@ -94,12 +94,11 @@ class RBLNBartModel(RBLNModel):
             for model_input_name in rbln_model_input_names
         ]

-
-        dec_compile_config = RBLNCompileConfig(input_info=input_info, compiled_model_name="decoder")
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)

         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
-            compile_cfgs=[dec_compile_config],
+            compile_cfgs=[rbln_compile_config],
             rbln_kwargs=rbln_kwargs,
         )

@@ -108,6 +107,8 @@


 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    support_causal_attn = True
+
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
         enc_max_seq_len = (
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
CHANGED
@@ -222,8 +222,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):

         attention_mask = self.dec_attn_mask

-        attention_mask = self.dec_attn_mask
-
         logits = super().forward(
             inputs,
             cache_position,
@@ -547,22 +545,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

         @QuantizationManager.with_quantization_env
         def compile_model(*args, **kwargs):
-            wrapped_model.phase = "prefill"
-            compiled_prefill = RBLNModel.compile(
-                wrapped_model,
-                prefill_compile_config,
-                example_inputs=prefill_example_inputs,
-                compile_context=context,
-            )
+            try:
+                original_linear = torch.nn.functional.linear
+                torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
+                wrapped_model.phase = "prefill"
+                compiled_prefill = RBLNModel.compile(
+                    wrapped_model,
+                    prefill_compile_config,
+                    example_inputs=prefill_example_inputs,
+                    compile_context=context,
+                )

-            wrapped_model.phase = "decode"
-            compiled_decoder = RBLNModel.compile(
-                wrapped_model,
-                dec_compile_config,
-                example_inputs=dec_example_inputs,
-                compile_context=context,
-            )
-            return {"prefill": compiled_prefill, "decoder": compiled_decoder}
+                wrapped_model.phase = "decode"
+                compiled_decoder = RBLNModel.compile(
+                    wrapped_model,
+                    dec_compile_config,
+                    example_inputs=dec_example_inputs,
+                    compile_context=context,
+                )
+                return {"prefill": compiled_prefill, "decoder": compiled_decoder}
+            finally:
+                torch.nn.functional.linear = original_linear

         return compile_model(quantize_config=quantize_config)

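
The compile path now swaps torch.nn.functional.linear for the custom op for the duration of both compile calls and restores it in a finally block, so the patch cannot leak if compilation fails. The same pattern in isolation (a generic sketch, not the library API):

    import torch
    import torch.nn.functional as F


    def run_with_rbln_linear(fn, *args, **kwargs):
        """Run fn with F.linear temporarily routed to the RBLN custom linear op."""
        original_linear = F.linear
        F.linear = torch.ops.rbln_custom_ops.linear
        try:
            return fn(*args, **kwargs)
        finally:
            # Restore the original implementation even if fn raises.
            F.linear = original_linear

Patching at the functional level routes every nn.Linear in the wrapped model through the custom op during compilation without rewriting the model definition.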
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
CHANGED
@@ -38,8 +38,8 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]

     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-        _ = super().forward(*args, **kwargs)
-        return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
+        output = super().forward(*args, **kwargs)
+        return BaseModelOutput(last_hidden_state=output)


 class RBLNRuntimeDecoder(RBLNPytorchRuntime):
@@ -94,7 +94,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
             decoder_attention_mask if self.use_attention_mask else None,
             attention_mask,
             cache_position,
-            block_tables,
+            block_tables=block_tables,
         )

         return Seq2SeqLMOutput(logits=lm_logits)
@@ -115,6 +115,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):

     main_input_name = "input_ids"
     auto_model_class = AutoModelForSeq2SeqLM
+    support_causal_attn = None

     def __post_init__(self, **kwargs):
         batch_size = self.rbln_config.model_cfg["batch_size"]
@@ -186,13 +187,16 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)

-        if rbln_use_attention_mask is None:
-            rbln_use_attention_mask = False
-            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
-            if rbln_npu == "RBLN-CA02":
-                rbln_use_attention_mask = True
+        if cls.support_causal_attn:
+            rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
+            if rbln_use_attention_mask is None:
+                rbln_use_attention_mask = False
+                rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+                if rbln_npu == "RBLN-CA02":
+                    rbln_use_attention_mask = True
+        else:
+            rbln_use_attention_mask = True

         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
@@ -265,11 +269,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 [rbln_batch_size, 1],
                 "int32",
             ),
-            (
-                "block_tables",
-                [rbln_batch_size, 1],
-                "int16",
-            ),
+            ("block_tables", [rbln_batch_size, 1], "int16"),
         ]
         dec_input_info.extend(
             [
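
The new support_causal_attn switch moves the attention-mask decision into the shared seq2seq base class: subclasses that support causal attention (BART) leave the mask off unless the user requests it or the target NPU is RBLN-CA02, while subclasses that do not (T5) always get a mask. A standalone rendering of that branching (simplified sketch, not the library API):

    def resolve_use_attention_mask(support_causal_attn, user_value, npu_name):
        # Mirrors the branching added to _get_rbln_config above (simplified).
        if support_causal_attn:
            if user_value is not None:
                return bool(user_value)
            # The mask is forced on for RBLN-CA02 and left off otherwise.
            return npu_name == "RBLN-CA02"
        # Models compiled without causal attention (support_causal_attn = False) always use a mask.
        return True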
optimum/rbln/transformers/models/t5/modeling_t5.py
CHANGED
@@ -13,9 +13,8 @@
 # limitations under the License.

 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

-import rebel
 import torch
 from transformers import (
     AutoModelForTextEncoding,
@@ -23,7 +22,7 @@ from transformers import (
     T5EncoderModel,
     T5ForConditionalGeneration,
 )
-from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+from transformers.modeling_outputs import BaseModelOutput

 from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
 from ....modeling import RBLNModel
@@ -58,63 +57,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )


-class RBLNRuntimeEncoder(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
-
-    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-        _ = super().forward(*args, **kwargs)
-        return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
-
-
-class RBLNRuntimeDecoder(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
-
-    def __init__(
-        self,
-        runtime: rebel.Runtime,
-        batch_size: int,
-        dec_max_seq_len: int,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(runtime, **kwargs)
-        self.batch_size = batch_size
-        self.dec_max_seq_len = dec_max_seq_len
-
-    def forward(
-        self,
-        decoder_input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        decoder_attention_mask: Optional[torch.BoolTensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        batch_size = decoder_input_ids.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
-
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-        for b_idx in range(self.batch_size):
-            decoding_step = cache_position[b_idx].item()
-            if not (0 <= decoding_step < self.dec_max_seq_len):
-                raise ValueError(
-                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                )
-            decoder_attention_mask[b_idx, : decoding_step + 1] = 1
-
-        lm_logits = super().forward(
-            decoder_input_ids,
-            decoder_attention_mask,
-            attention_mask,
-            cache_position,
-        )
-
-        return Seq2SeqLMOutput(logits=lm_logits)
-
-
 class T5EncoderWrapper(torch.nn.Module):
     def __init__(self, model: "T5EncoderModel") -> None:
         super().__init__()
@@ -247,20 +189,7 @@ class RBLNT5EncoderModel(RBLNModel):


 class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
-    def __post_init__(self, **kwargs):
-        batch_size = self.rbln_config.model_cfg["batch_size"]
-        dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
-
-        self.encoder = RBLNRuntimeEncoder(
-            runtime=self.model[0],
-            main_input_name="input_ids",
-        )
-        self.decoder = RBLNRuntimeDecoder(
-            runtime=self.model[1],
-            main_input_name="input_ids",
-            batch_size=batch_size,
-            dec_max_seq_len=dec_max_seq_len,
-        )
+    support_causal_attn = False

     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
@@ -279,139 +208,3 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
             return redirect(val)

         return val
-
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
-        rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-        n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
-        n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
-        d_kv = (
-            model_config.d_kv
-            if hasattr(model_config, "d_kv")
-            else model_config.d_model // model_config.encoder_attention_heads
-        )
-
-        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
-            model_config, "max_position_embeddings", None
-        )
-
-        rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
-        if rbln_pad_token_id is None:
-            rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
-            if rbln_pad_token_id is None:
-                rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
-                if rbln_pad_token_id is None:
-                    rbln_pad_token_id = -1
-
-        if rbln_enc_max_seq_len is None:
-            rbln_enc_max_seq_len = max_position_embeddings
-            if rbln_enc_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_enc_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_enc_max_seq_len is None:
-                    raise ValueError("`rbln_enc_max_seq_len` should be specified!")
-        if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        if rbln_dec_max_seq_len is None:
-            rbln_dec_max_seq_len = max_position_embeddings
-            if rbln_dec_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_dec_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_dec_max_seq_len is None:
-                    raise ValueError("`rbln_dec_max_seq_len` should be specified!")
-
-        if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        # model input info
-        enc_input_info = [
-            ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
-            ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
-            (
-                "cross_key_value_states",
-                [
-                    n_layer * 2,
-                    rbln_batch_size,
-                    n_head,
-                    rbln_enc_max_seq_len,
-                    d_kv,
-                ],
-                "float32",
-            ),
-            ("block_tables", [1], "int16"),
-        ]
-
-        dec_input_info = [
-            ("input_ids", [rbln_batch_size, 1], "int64"),
-            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
-            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
-            (
-                "cache_position",
-                [rbln_batch_size, 1],
-                "int32",
-            ),
-        ]
-        dec_input_info.extend(
-            [
-                (
-                    "cross_key_value_states",
-                    [
-                        n_layer * 2,
-                        rbln_batch_size,
-                        n_head,
-                        rbln_enc_max_seq_len,
-                        d_kv,
-                    ],
-                    "float32",
-                )
-            ]
-        )
-        dec_input_info.extend(
-            [
-                (
-                    f"self_key_value_states_{i}",
-                    [
-                        rbln_batch_size,
-                        n_head,
-                        rbln_dec_max_seq_len,
-                        d_kv,
-                    ],
-                    "float32",
-                )
-                for i in range(n_layer * 2)
-            ]
-        )
-
-        enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
-        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[enc_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "enc_max_seq_len": rbln_enc_max_seq_len,
-                "dec_max_seq_len": rbln_dec_max_seq_len,
-                "batch_size": rbln_batch_size,
-                "pad_token_id": rbln_pad_token_id,
-            }
-        )
-
-        return rbln_config
optimum/rbln/transformers/models/t5/t5_architecture.py
CHANGED
@@ -18,7 +18,7 @@ import torch
 from torch import nn
 from transformers.utils import logging

-from ....ops import register_rbln_custom_add_softmax_attention
+from ....ops import register_rbln_custom_paged_add_softmax_attention
 from ..seq2seq.seq2seq_architecture import (
     Seq2SeqDecoder,
     Seq2SeqDecoderLayer,
@@ -55,7 +55,7 @@ class T5EncoderWrapper(Seq2SeqEncoderWrapper):

 class T5DecoderWrapper(Seq2SeqDecoderWrapper):
     def __post_init__(self, model, dec_max_seq_len: int = None):
-        register_rbln_custom_add_softmax_attention()
+        register_rbln_custom_paged_add_softmax_attention()
         self.num_layers = self.config.num_layers
         self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)

@@ -77,6 +77,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
         attention_mask,
         encoder_attention_mask,
         cache_position,
+        block_tables,
         cross_kv_cache,
         *self_kv_cache,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
@@ -95,6 +96,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
             self_past_key_values=self_past_key_values,
             cross_past_key_values=cross_past_key_values,
             cache_position=cache_position,
+            block_tables=block_tables,
         )

         return lm_logits
@@ -162,7 +164,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         self.out_proj = self._original_mod.o
         self.num_heads = self._original_mod.n_heads
         self.head_dim = self._original_mod.key_value_proj_dim
-        self.attn_decode = torch.ops.rbln_custom_ops.add_softmax_attn_decode
+        self.attn_decode = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode

     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states)
@@ -176,6 +178,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         past_key_value: Tuple[torch.Tensor],
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        block_tables: torch.Tensor,
         **kwargs,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         bsz, tgt_len, _ = hidden_states.size()
@@ -185,6 +188,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         key_states = self._shape(key_states, -1, bsz)
         value_states = self._shape(value_states, -1, bsz)

+        block_size = past_key_value[0].shape[-2]
         attn_output = self.attn_decode(
             query_states,
             key_states,
@@ -196,6 +200,8 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
             past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
             cache_position,
             torch.tensor(1.0, dtype=torch.float32),  # scale
+            block_tables,
+            block_size,
         )

         attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
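
In the T5 self-attention, block_size is read straight off the self-KV cache layout (past_key_value[0].shape[-2], the sequence axis of the per-layer cache), and block_tables arrives as a [batch, 1] int16 tensor per the dec_input_info defined earlier, so each sequence currently occupies a single block. A toy illustration of that derivation (the cache layout here is an assumption based on the docstring in ops/attn.py):

    import torch

    bsz, n_heads, head_dim, dec_max_seq_len = 2, 8, 64, 1024
    # Assumed self-KV cache layout: [batch, n_heads, seq, head_dim].
    self_k_cache = torch.zeros(bsz, n_heads, dec_max_seq_len, head_dim)

    block_size = self_k_cache.shape[-2]  # sequence axis, as in past_key_value[0].shape[-2]
    block_tables = torch.arange(bsz, dtype=torch.int16).reshape(bsz, 1)  # one block per sequence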
optimum/rbln/transformers/models/time_series_transformers/__init__.py
ADDED
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .modeling_time_series_transformers import RBLNTimeSeriesTransformerForPrediction