optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +173 -35
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +816 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +111 -137
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +56 -71
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
- optimum/rbln/modeling.py +66 -40
- optimum/rbln/modeling_base.py +111 -86
- optimum/rbln/ops/__init__.py +4 -7
- optimum/rbln/ops/attn.py +271 -205
- optimum/rbln/ops/flash_attn.py +161 -67
- optimum/rbln/ops/kv_cache_update.py +4 -40
- optimum/rbln/ops/linear.py +25 -0
- optimum/rbln/transformers/__init__.py +97 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +120 -32
- optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
- optimum/rbln/transformers/models/bart/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
- optimum/rbln/transformers/models/t5/__init__.py +2 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
- optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
- optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +2 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/hub.py +2 -2
- optimum/rbln/utils/import_utils.py +23 -6
- optimum/rbln/utils/model_utils.py +4 -4
- optimum/rbln/utils/runtime_utils.py +33 -2
- optimum/rbln/utils/submodule.py +36 -44
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
- optimum_rbln-0.7.4.dist-info/RECORD +169 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
@@ -18,12 +18,6 @@ import torch
|
|
18
18
|
from torch import nn
|
19
19
|
from transformers.utils import logging
|
20
20
|
|
21
|
-
from ....ops import (
|
22
|
-
register_rbln_custom_cache_update,
|
23
|
-
register_rbln_custom_paged_attention,
|
24
|
-
register_rbln_custom_paged_causal_attention,
|
25
|
-
)
|
26
|
-
|
27
21
|
|
28
22
|
logger = logging.get_logger(__name__)
|
29
23
|
|
@@ -59,7 +53,6 @@ class Seq2SeqEncoderWrapper(nn.Module):
|
|
59
53
|
|
60
54
|
def __init__(self, model: nn.Module, enc_max_seq_len: int):
|
61
55
|
super().__init__()
|
62
|
-
register_rbln_custom_cache_update()
|
63
56
|
self.config = model.config
|
64
57
|
self.encoder = model.get_encoder()
|
65
58
|
self.encoder_max_length = enc_max_seq_len
|
@@ -90,8 +83,8 @@ class Seq2SeqEncoderWrapper(nn.Module):
|
|
90
83
|
self,
|
91
84
|
input_ids: torch.Tensor,
|
92
85
|
attention_mask: torch.Tensor,
|
93
|
-
cross_key_values: torch.Tensor,
|
94
86
|
b_idx: torch.Tensor,
|
87
|
+
*cross_key_values: Tuple[torch.Tensor],
|
95
88
|
) -> Tuple[torch.Tensor]:
|
96
89
|
# 1. get encoder last_hidden_states
|
97
90
|
encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
|
@@ -110,13 +103,15 @@ class Seq2SeqEncoderWrapper(nn.Module):
|
|
110
103
|
cross_kv.append(past_k)
|
111
104
|
cross_kv.append(past_v)
|
112
105
|
|
113
|
-
cross_kv = torch.stack(cross_kv, dim=0)
|
114
|
-
|
115
106
|
# 3. update the cross_attention's past_key_value direct to the device-dram for optimization.
|
116
|
-
batch_axis = torch.tensor(
|
117
|
-
|
107
|
+
batch_axis = torch.tensor(0, dtype=torch.int16)
|
108
|
+
cross_key_values = list(cross_key_values)
|
109
|
+
for i in range(self.n_layer * 2):
|
110
|
+
cross_key_values[i] = torch.ops.rbln_custom_ops.rbln_cache_update(
|
111
|
+
cross_key_values[i], cross_kv[i], b_idx[0], batch_axis
|
112
|
+
)
|
118
113
|
|
119
|
-
return
|
114
|
+
return cross_key_values
|
120
115
|
|
121
116
|
|
122
117
|
class Seq2SeqDecoderWrapper(nn.Module):
|
@@ -146,11 +141,6 @@ class Seq2SeqDecoderWrapper(nn.Module):
|
|
146
141
|
It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
|
147
142
|
by subclasses to modify or add custom attributes as necessary.
|
148
143
|
"""
|
149
|
-
if self.use_attention_mask:
|
150
|
-
register_rbln_custom_paged_attention()
|
151
|
-
else:
|
152
|
-
register_rbln_custom_paged_causal_attention()
|
153
|
-
|
154
144
|
self.num_layers = self.config.decoder_layers
|
155
145
|
self.conditional_generation = self.convert_to_rbln_conditional_generation(model)
|
156
146
|
|
@@ -176,16 +166,17 @@ class Seq2SeqDecoderWrapper(nn.Module):
|
|
176
166
|
encoder_attention_mask,
|
177
167
|
cache_position,
|
178
168
|
block_tables,
|
179
|
-
|
180
|
-
*self_kv_cache,
|
169
|
+
*kv_cache,
|
181
170
|
) = args
|
182
171
|
|
183
172
|
else:
|
184
173
|
attention_mask = None
|
185
|
-
(input_ids, encoder_attention_mask, cache_position, block_tables,
|
174
|
+
(input_ids, encoder_attention_mask, cache_position, block_tables, *kv_cache) = args
|
186
175
|
|
187
176
|
self_past_key_values = ()
|
188
177
|
cross_past_key_values = ()
|
178
|
+
self_kv_cache = kv_cache[self.num_layers * 2 :]
|
179
|
+
cross_kv_cache = kv_cache[: self.num_layers * 2]
|
189
180
|
for i in range(0, self.num_layers * 2, 2):
|
190
181
|
self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
|
191
182
|
cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
|
@@ -12,4 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
from ....ops import paged_add_softmax_attn_decode
|
16
|
+
from .configuration_t5 import RBLNT5EncoderModelConfig, RBLNT5ForConditionalGenerationConfig
|
15
17
|
from .modeling_t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
|
16
|
+
from ..seq2seq import RBLNModelForSeq2SeqLMConfig
|
17
|
+
|
18
|
+
|
19
|
+
class RBLNT5EncoderModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
|
20
|
+
pass
|
21
|
+
|
22
|
+
|
23
|
+
class RBLNT5ForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
|
24
|
+
pass
|
@@ -13,106 +13,21 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import inspect
|
16
|
-
from typing import TYPE_CHECKING, Any, Callable
|
16
|
+
from typing import TYPE_CHECKING, Any, Callable
|
17
17
|
|
18
|
-
import rebel
|
19
18
|
import torch
|
20
|
-
from transformers import
|
21
|
-
AutoModelForTextEncoding,
|
22
|
-
PretrainedConfig,
|
23
|
-
T5EncoderModel,
|
24
|
-
T5ForConditionalGeneration,
|
25
|
-
)
|
26
|
-
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
19
|
+
from transformers import AutoModelForTextEncoding, T5EncoderModel, T5ForConditionalGeneration
|
27
20
|
|
28
|
-
from
|
29
|
-
from ....modeling import RBLNModel
|
30
|
-
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
31
|
-
from ....utils.logging import get_logger
|
32
|
-
from ....utils.runtime_utils import RBLNPytorchRuntime
|
21
|
+
from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
|
33
22
|
from ...models.seq2seq import RBLNModelForSeq2SeqLM
|
23
|
+
from .configuration_t5 import RBLNT5EncoderModelConfig, RBLNT5ForConditionalGenerationConfig
|
34
24
|
from .t5_architecture import T5Wrapper
|
35
25
|
|
36
26
|
|
37
|
-
logger = get_logger()
|
38
|
-
|
39
27
|
if TYPE_CHECKING:
|
40
|
-
from transformers import
|
41
|
-
|
42
|
-
|
43
|
-
class RBLNRuntimeModel(RBLNPytorchRuntime):
|
44
|
-
def forward(
|
45
|
-
self,
|
46
|
-
input_ids: torch.LongTensor,
|
47
|
-
attention_mask: torch.FloatTensor,
|
48
|
-
head_mask: torch.FloatTensor,
|
49
|
-
inputs_embeds: torch.FloatTensor,
|
50
|
-
**kwargs,
|
51
|
-
):
|
52
|
-
return super().forward(
|
53
|
-
input_ids,
|
54
|
-
attention_mask,
|
55
|
-
head_mask,
|
56
|
-
inputs_embeds,
|
57
|
-
**kwargs,
|
58
|
-
)
|
59
|
-
|
60
|
-
|
61
|
-
class RBLNRuntimeEncoder(RBLNPytorchRuntime):
|
62
|
-
mandatory_members = ["main_input_name"]
|
63
|
-
|
64
|
-
def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
|
65
|
-
_ = super().forward(*args, **kwargs)
|
66
|
-
return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
|
67
|
-
|
28
|
+
from transformers import PreTrainedModel
|
68
29
|
|
69
|
-
|
70
|
-
mandatory_members = ["main_input_name"]
|
71
|
-
|
72
|
-
def __init__(
|
73
|
-
self,
|
74
|
-
runtime: rebel.Runtime,
|
75
|
-
batch_size: int,
|
76
|
-
dec_max_seq_len: int,
|
77
|
-
**kwargs: Any,
|
78
|
-
) -> None:
|
79
|
-
super().__init__(runtime, **kwargs)
|
80
|
-
self.batch_size = batch_size
|
81
|
-
self.dec_max_seq_len = dec_max_seq_len
|
82
|
-
|
83
|
-
def forward(
|
84
|
-
self,
|
85
|
-
decoder_input_ids: Optional[torch.LongTensor] = None,
|
86
|
-
attention_mask: Optional[torch.FloatTensor] = None,
|
87
|
-
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
88
|
-
cache_position: Optional[torch.Tensor] = None,
|
89
|
-
**kwargs,
|
90
|
-
) -> Tuple[torch.FloatTensor]:
|
91
|
-
batch_size = decoder_input_ids.shape[0]
|
92
|
-
if batch_size != self.batch_size:
|
93
|
-
raise RuntimeError(
|
94
|
-
f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
|
95
|
-
)
|
96
|
-
|
97
|
-
if batch_size != cache_position.shape[0]:
|
98
|
-
raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
|
99
|
-
|
100
|
-
for b_idx in range(self.batch_size):
|
101
|
-
decoding_step = cache_position[b_idx].item()
|
102
|
-
if not (0 <= decoding_step < self.dec_max_seq_len):
|
103
|
-
raise ValueError(
|
104
|
-
f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
|
105
|
-
)
|
106
|
-
decoder_attention_mask[b_idx, : decoding_step + 1] = 1
|
107
|
-
|
108
|
-
lm_logits = super().forward(
|
109
|
-
decoder_input_ids,
|
110
|
-
decoder_attention_mask,
|
111
|
-
attention_mask,
|
112
|
-
cache_position,
|
113
|
-
)
|
114
|
-
|
115
|
-
return Seq2SeqLMOutput(logits=lm_logits)
|
30
|
+
from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
116
31
|
|
117
32
|
|
118
33
|
class T5EncoderWrapper(torch.nn.Module):
|
@@ -125,149 +40,35 @@ class T5EncoderWrapper(torch.nn.Module):
|
|
125
40
|
return self.model(*args, **kwargs, return_dict=False)
|
126
41
|
|
127
42
|
|
128
|
-
class RBLNT5EncoderModel(
|
43
|
+
class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
|
129
44
|
auto_model_class = AutoModelForTextEncoding
|
130
45
|
rbln_model_input_names = ["input_ids", "attention_mask"]
|
131
46
|
|
132
|
-
def __post_init__(self, **kwargs):
|
133
|
-
self.model = RBLNRuntimeModel(runtime=self.model[0])
|
134
|
-
|
135
47
|
@classmethod
|
136
|
-
def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config:
|
48
|
+
def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
|
137
49
|
return T5EncoderWrapper(model)
|
138
50
|
|
139
51
|
@classmethod
|
140
|
-
def update_rbln_config_using_pipe(
|
141
|
-
batch_size = rbln_config.get("batch_size", 1)
|
142
|
-
max_sequence_length = rbln_config.get("max_sequence_length", 256)
|
143
|
-
model_input_names = ["input_ids"]
|
144
|
-
|
145
|
-
rbln_config.update(
|
146
|
-
{
|
147
|
-
"batch_size": batch_size,
|
148
|
-
"max_seq_len": max_sequence_length,
|
149
|
-
"model_input_names": model_input_names,
|
150
|
-
}
|
151
|
-
)
|
152
|
-
|
153
|
-
return rbln_config
|
154
|
-
|
155
|
-
@classmethod
|
156
|
-
def _get_rbln_config(
|
52
|
+
def update_rbln_config_using_pipe(
|
157
53
|
cls,
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
) ->
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
max_position_embeddings = getattr(model_config, "n_positions", None)
|
167
|
-
|
168
|
-
if rbln_max_seq_len is None:
|
169
|
-
rbln_max_seq_len = max_position_embeddings
|
170
|
-
if rbln_max_seq_len is None:
|
171
|
-
for tokenizer in preprocessors:
|
172
|
-
if hasattr(tokenizer, "model_max_length"):
|
173
|
-
rbln_max_seq_len = tokenizer.model_max_length
|
174
|
-
break
|
175
|
-
if rbln_max_seq_len is None:
|
176
|
-
raise ValueError("`rbln_max_seq_len` should be specified!")
|
177
|
-
|
178
|
-
if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
|
179
|
-
raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
|
180
|
-
|
181
|
-
signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
|
182
|
-
|
183
|
-
if rbln_model_input_names is None:
|
184
|
-
for tokenizer in preprocessors:
|
185
|
-
if hasattr(tokenizer, "model_input_names"):
|
186
|
-
rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
|
187
|
-
|
188
|
-
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
189
|
-
if invalid_params:
|
190
|
-
raise ValueError(f"Invalid model input names: {invalid_params}")
|
191
|
-
break
|
192
|
-
if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
|
193
|
-
rbln_model_input_names = cls.rbln_model_input_names
|
194
|
-
elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
|
195
|
-
raise ValueError(
|
196
|
-
"Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
|
197
|
-
f"and be sure to make the order of the inputs same as T5EncoderModel forward() arguments like ({list(signature_params)})"
|
198
|
-
)
|
199
|
-
else:
|
200
|
-
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
201
|
-
if invalid_params:
|
202
|
-
raise ValueError(f"Invalid model input names: {invalid_params}")
|
203
|
-
rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
|
204
|
-
|
205
|
-
if rbln_batch_size is None:
|
206
|
-
rbln_batch_size = 1
|
207
|
-
|
208
|
-
input_info = [
|
209
|
-
(model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
|
210
|
-
for model_input_name in rbln_model_input_names
|
211
|
-
]
|
212
|
-
|
213
|
-
rbln_compile_config = RBLNCompileConfig(input_info=input_info)
|
214
|
-
|
215
|
-
rbln_config = RBLNConfig(
|
216
|
-
rbln_cls=cls.__name__,
|
217
|
-
compile_cfgs=[rbln_compile_config],
|
218
|
-
rbln_kwargs=rbln_kwargs,
|
219
|
-
)
|
220
|
-
|
221
|
-
rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
|
54
|
+
pipe: "RBLNDiffusionMixin",
|
55
|
+
rbln_config: "RBLNDiffusionMixinConfig",
|
56
|
+
submodule_name: str,
|
57
|
+
) -> "RBLNDiffusionMixinConfig":
|
58
|
+
submodule_config = getattr(rbln_config, submodule_name)
|
59
|
+
submodule_config.max_seq_len = rbln_config.max_seq_len or 256
|
60
|
+
submodule_config.model_input_names = ["input_ids"]
|
222
61
|
return rbln_config
|
223
62
|
|
224
|
-
def forward(
|
225
|
-
self,
|
226
|
-
input_ids: Optional[torch.LongTensor] = None,
|
227
|
-
attention_mask: Optional[torch.FloatTensor] = None,
|
228
|
-
head_mask: Optional[torch.FloatTensor] = None,
|
229
|
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
230
|
-
output_attentions: Optional[bool] = None,
|
231
|
-
output_hidden_states: Optional[bool] = None,
|
232
|
-
return_dict: Optional[bool] = None,
|
233
|
-
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
|
234
|
-
encoder_outputs = self.model(
|
235
|
-
input_ids=input_ids,
|
236
|
-
attention_mask=attention_mask,
|
237
|
-
inputs_embeds=inputs_embeds,
|
238
|
-
head_mask=head_mask,
|
239
|
-
output_attentions=output_attentions,
|
240
|
-
output_hidden_states=output_hidden_states,
|
241
|
-
return_dict=return_dict,
|
242
|
-
)
|
243
|
-
if not return_dict:
|
244
|
-
return (encoder_outputs,)
|
245
|
-
else:
|
246
|
-
return BaseModelOutput(last_hidden_state=encoder_outputs)
|
247
|
-
|
248
63
|
|
249
64
|
class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
|
250
|
-
|
251
|
-
batch_size = self.rbln_config.model_cfg["batch_size"]
|
252
|
-
dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
|
253
|
-
|
254
|
-
self.encoder = RBLNRuntimeEncoder(
|
255
|
-
runtime=self.model[0],
|
256
|
-
main_input_name="input_ids",
|
257
|
-
)
|
258
|
-
self.decoder = RBLNRuntimeDecoder(
|
259
|
-
runtime=self.model[1],
|
260
|
-
main_input_name="input_ids",
|
261
|
-
batch_size=batch_size,
|
262
|
-
dec_max_seq_len=dec_max_seq_len,
|
263
|
-
)
|
65
|
+
support_causal_attn = False
|
264
66
|
|
265
67
|
@classmethod
|
266
|
-
def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config:
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
return T5Wrapper(model, enc_max_seq_len=enc_max_seq_len, dec_max_seq_len=dec_max_seq_len)
|
68
|
+
def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5ForConditionalGenerationConfig):
|
69
|
+
return T5Wrapper(
|
70
|
+
model, enc_max_seq_len=rbln_config.enc_max_seq_len, dec_max_seq_len=rbln_config.dec_max_seq_len
|
71
|
+
)
|
271
72
|
|
272
73
|
def __getattr__(self, __name: str) -> Any:
|
273
74
|
def redirect(func):
|
@@ -279,139 +80,3 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
|
|
279
80
|
return redirect(val)
|
280
81
|
|
281
82
|
return val
|
282
|
-
|
283
|
-
@classmethod
|
284
|
-
def _get_rbln_config(
|
285
|
-
cls,
|
286
|
-
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
287
|
-
model_config: "PretrainedConfig",
|
288
|
-
rbln_kwargs: Dict[str, Any] = {},
|
289
|
-
) -> RBLNConfig:
|
290
|
-
rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
|
291
|
-
rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
|
292
|
-
rbln_batch_size = rbln_kwargs.get("batch_size", None)
|
293
|
-
rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
|
294
|
-
|
295
|
-
n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
|
296
|
-
n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
|
297
|
-
d_kv = (
|
298
|
-
model_config.d_kv
|
299
|
-
if hasattr(model_config, "d_kv")
|
300
|
-
else model_config.d_model // model_config.encoder_attention_heads
|
301
|
-
)
|
302
|
-
|
303
|
-
max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
|
304
|
-
model_config, "max_position_embeddings", None
|
305
|
-
)
|
306
|
-
|
307
|
-
rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
|
308
|
-
if rbln_pad_token_id is None:
|
309
|
-
rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
|
310
|
-
if rbln_pad_token_id is None:
|
311
|
-
rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
|
312
|
-
if rbln_pad_token_id is None:
|
313
|
-
rbln_pad_token_id = -1
|
314
|
-
|
315
|
-
if rbln_enc_max_seq_len is None:
|
316
|
-
rbln_enc_max_seq_len = max_position_embeddings
|
317
|
-
if rbln_enc_max_seq_len is None:
|
318
|
-
for tokenizer in preprocessors:
|
319
|
-
if hasattr(tokenizer, "model_max_length"):
|
320
|
-
rbln_enc_max_seq_len = tokenizer.model_max_length
|
321
|
-
break
|
322
|
-
if rbln_enc_max_seq_len is None:
|
323
|
-
raise ValueError("`rbln_enc_max_seq_len` should be specified!")
|
324
|
-
if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
|
325
|
-
raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
|
326
|
-
|
327
|
-
if rbln_dec_max_seq_len is None:
|
328
|
-
rbln_dec_max_seq_len = max_position_embeddings
|
329
|
-
if rbln_dec_max_seq_len is None:
|
330
|
-
for tokenizer in preprocessors:
|
331
|
-
if hasattr(tokenizer, "model_max_length"):
|
332
|
-
rbln_dec_max_seq_len = tokenizer.model_max_length
|
333
|
-
break
|
334
|
-
if rbln_dec_max_seq_len is None:
|
335
|
-
raise ValueError("`rbln_dec_max_seq_len` should be specified!")
|
336
|
-
|
337
|
-
if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
|
338
|
-
raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
|
339
|
-
|
340
|
-
# model input info
|
341
|
-
enc_input_info = [
|
342
|
-
("input_ids", [1, rbln_enc_max_seq_len], "int64"),
|
343
|
-
("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
|
344
|
-
(
|
345
|
-
"cross_key_value_states",
|
346
|
-
[
|
347
|
-
n_layer * 2,
|
348
|
-
rbln_batch_size,
|
349
|
-
n_head,
|
350
|
-
rbln_enc_max_seq_len,
|
351
|
-
d_kv,
|
352
|
-
],
|
353
|
-
"float32",
|
354
|
-
),
|
355
|
-
("block_tables", [1], "int16"),
|
356
|
-
]
|
357
|
-
|
358
|
-
dec_input_info = [
|
359
|
-
("input_ids", [rbln_batch_size, 1], "int64"),
|
360
|
-
("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
|
361
|
-
("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
|
362
|
-
(
|
363
|
-
"cache_position",
|
364
|
-
[rbln_batch_size, 1],
|
365
|
-
"int32",
|
366
|
-
),
|
367
|
-
]
|
368
|
-
dec_input_info.extend(
|
369
|
-
[
|
370
|
-
(
|
371
|
-
"cross_key_value_states",
|
372
|
-
[
|
373
|
-
n_layer * 2,
|
374
|
-
rbln_batch_size,
|
375
|
-
n_head,
|
376
|
-
rbln_enc_max_seq_len,
|
377
|
-
d_kv,
|
378
|
-
],
|
379
|
-
"float32",
|
380
|
-
)
|
381
|
-
]
|
382
|
-
)
|
383
|
-
dec_input_info.extend(
|
384
|
-
[
|
385
|
-
(
|
386
|
-
f"self_key_value_states_{i}",
|
387
|
-
[
|
388
|
-
rbln_batch_size,
|
389
|
-
n_head,
|
390
|
-
rbln_dec_max_seq_len,
|
391
|
-
d_kv,
|
392
|
-
],
|
393
|
-
"float32",
|
394
|
-
)
|
395
|
-
for i in range(n_layer * 2)
|
396
|
-
]
|
397
|
-
)
|
398
|
-
|
399
|
-
enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
|
400
|
-
dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
|
401
|
-
|
402
|
-
rbln_config = RBLNConfig(
|
403
|
-
rbln_cls=cls.__name__,
|
404
|
-
compile_cfgs=[enc_compile_config, dec_compile_config],
|
405
|
-
rbln_kwargs=rbln_kwargs,
|
406
|
-
)
|
407
|
-
|
408
|
-
rbln_config.model_cfg.update(
|
409
|
-
{
|
410
|
-
"enc_max_seq_len": rbln_enc_max_seq_len,
|
411
|
-
"dec_max_seq_len": rbln_dec_max_seq_len,
|
412
|
-
"batch_size": rbln_batch_size,
|
413
|
-
"pad_token_id": rbln_pad_token_id,
|
414
|
-
}
|
415
|
-
)
|
416
|
-
|
417
|
-
return rbln_config
|
@@ -18,7 +18,6 @@ import torch
|
|
18
18
|
from torch import nn
|
19
19
|
from transformers.utils import logging
|
20
20
|
|
21
|
-
from ....ops import register_rbln_custom_add_softmax_attention
|
22
21
|
from ..seq2seq.seq2seq_architecture import (
|
23
22
|
Seq2SeqDecoder,
|
24
23
|
Seq2SeqDecoderLayer,
|
@@ -55,7 +54,6 @@ class T5EncoderWrapper(Seq2SeqEncoderWrapper):
|
|
55
54
|
|
56
55
|
class T5DecoderWrapper(Seq2SeqDecoderWrapper):
|
57
56
|
def __post_init__(self, model, dec_max_seq_len: int = None):
|
58
|
-
register_rbln_custom_add_softmax_attention()
|
59
57
|
self.num_layers = self.config.num_layers
|
60
58
|
self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)
|
61
59
|
|
@@ -77,11 +75,13 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
|
|
77
75
|
attention_mask,
|
78
76
|
encoder_attention_mask,
|
79
77
|
cache_position,
|
80
|
-
|
81
|
-
*
|
78
|
+
block_tables,
|
79
|
+
*kv_cache,
|
82
80
|
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
|
83
81
|
self_past_key_values = ()
|
84
82
|
cross_past_key_values = ()
|
83
|
+
self_kv_cache = kv_cache[self.num_layers * 2 :]
|
84
|
+
cross_kv_cache = kv_cache[: self.num_layers * 2]
|
85
85
|
|
86
86
|
for i in range(0, self.num_layers * 2, 2):
|
87
87
|
self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
|
@@ -95,6 +95,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
|
|
95
95
|
self_past_key_values=self_past_key_values,
|
96
96
|
cross_past_key_values=cross_past_key_values,
|
97
97
|
cache_position=cache_position,
|
98
|
+
block_tables=block_tables,
|
98
99
|
)
|
99
100
|
|
100
101
|
return lm_logits
|
@@ -162,7 +163,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
|
|
162
163
|
self.out_proj = self._original_mod.o
|
163
164
|
self.num_heads = self._original_mod.n_heads
|
164
165
|
self.head_dim = self._original_mod.key_value_proj_dim
|
165
|
-
self.attn_decode = torch.ops.rbln_custom_ops.
|
166
|
+
self.attn_decode = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode
|
166
167
|
|
167
168
|
def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
168
169
|
query_states = self.q_proj(hidden_states)
|
@@ -176,6 +177,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
|
|
176
177
|
past_key_value: Tuple[torch.Tensor],
|
177
178
|
attention_mask: torch.Tensor,
|
178
179
|
cache_position: torch.Tensor,
|
180
|
+
block_tables: torch.Tensor,
|
179
181
|
**kwargs,
|
180
182
|
) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
|
181
183
|
bsz, tgt_len, _ = hidden_states.size()
|
@@ -185,6 +187,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
|
|
185
187
|
key_states = self._shape(key_states, -1, bsz)
|
186
188
|
value_states = self._shape(value_states, -1, bsz)
|
187
189
|
|
190
|
+
block_size = past_key_value[0].shape[-2]
|
188
191
|
attn_output = self.attn_decode(
|
189
192
|
query_states,
|
190
193
|
key_states,
|
@@ -196,6 +199,8 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
|
|
196
199
|
past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
|
197
200
|
cache_position,
|
198
201
|
torch.tensor(1.0, dtype=torch.float32), # scale
|
202
|
+
block_tables,
|
203
|
+
block_size,
|
199
204
|
)
|
200
205
|
|
201
206
|
attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
|
@@ -0,0 +1,26 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from ....ops import paged_add_softmax_attn_decode, rbln_cache_update
|
25
|
+
from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
|
26
|
+
from .modeling_time_series_transformers import RBLNTimeSeriesTransformerForPrediction
|
optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py
ADDED
@@ -0,0 +1,34 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
from ....configuration_utils import RBLNModelConfig
|
4
|
+
|
5
|
+
|
6
|
+
class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
|
7
|
+
def __init__(
|
8
|
+
self,
|
9
|
+
batch_size: Optional[int] = None,
|
10
|
+
enc_max_seq_len: Optional[int] = None,
|
11
|
+
dec_max_seq_len: Optional[int] = None,
|
12
|
+
num_parallel_samples: Optional[int] = None,
|
13
|
+
**kwargs,
|
14
|
+
):
|
15
|
+
"""
|
16
|
+
Args:
|
17
|
+
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
18
|
+
enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
|
19
|
+
dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
|
20
|
+
num_parallel_samples (Optional[int]): Number of samples to generate in parallel during prediction.
|
21
|
+
**kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
22
|
+
|
23
|
+
Raises:
|
24
|
+
ValueError: If batch_size is not a positive integer.
|
25
|
+
"""
|
26
|
+
super().__init__(**kwargs)
|
27
|
+
|
28
|
+
self.batch_size = batch_size or 1
|
29
|
+
if not isinstance(self.batch_size, int) or self.batch_size <= 0:
|
30
|
+
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
31
|
+
|
32
|
+
self.enc_max_seq_len = enc_max_seq_len
|
33
|
+
self.dec_max_seq_len = dec_max_seq_len
|
34
|
+
self.num_parallel_samples = num_parallel_samples
|