optimum-rbln 0.7.4a3__py3-none-any.whl → 0.7.4a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +156 -36
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/configuration_utils.py +772 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +63 -122
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +55 -70
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- optimum/rbln/modeling.py +58 -39
- optimum/rbln/modeling_base.py +85 -80
- optimum/rbln/transformers/__init__.py +79 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +96 -34
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/submodule.py +26 -43
- {optimum_rbln-0.7.4a3.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
- optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.4a3.dist-info/RECORD +0 -126
- {optimum_rbln-0.7.4a3.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a3.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
@@ -12,4 +12,5 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
from .configuration_llava_next import RBLNLlavaNextForConditionalGenerationConfig
|
15
16
|
from .modeling_llava_next import RBLNLlavaNextForConditionalGeneration
|
@@ -0,0 +1,46 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Optional
|
16
|
+
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
18
|
+
|
19
|
+
|
20
|
+
class RBLNLlavaNextForConditionalGenerationConfig(RBLNModelConfig):
    """RBLN configuration for LLaVA-NeXT conditional generation.

    Stores the inference batch size plus optional configurations for the two
    sub-components named in ``submodules``.
    """

    submodules = ["vision_tower", "language_model"]

    def __init__(
        self,
        batch_size: Optional[int] = None,
        vision_tower: Optional[RBLNModelConfig] = None,
        language_model: Optional[RBLNModelConfig] = None,
        **kwargs,
    ):
        """
        Args:
            batch_size (Optional[int]): The batch size for inference. Defaults to 1 when None.
            vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
            language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
            **kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        super().__init__(**kwargs)
        # Apply the default only when the value is missing. Using `batch_size or 1`
        # would silently turn an explicit 0 into 1 and bypass the check below.
        self.batch_size = 1 if batch_size is None else batch_size
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(self.batch_size, bool) or not isinstance(self.batch_size, int) or self.batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        self.vision_tower = vision_tower
        self.language_model = language_model
|
@@ -26,8 +26,8 @@ from transformers import (
|
|
26
26
|
)
|
27
27
|
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
28
28
|
|
29
|
+
from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
|
29
30
|
from ....modeling import RBLNModel
|
30
|
-
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
31
31
|
from ....utils.logging import get_logger
|
32
32
|
from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput
|
33
33
|
|
@@ -134,7 +134,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
|
|
134
134
|
model: "LlavaNextForConditionalGeneration",
|
135
135
|
save_dir_path: Path,
|
136
136
|
subfolder: str,
|
137
|
-
rbln_config:
|
137
|
+
rbln_config: RBLNModelConfig,
|
138
138
|
):
|
139
139
|
"""
|
140
140
|
If you are unavoidably running on a CPU rather than an RBLN device,
|
@@ -161,42 +161,31 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
|
|
161
161
|
return self.language_model.get_input_embeddings()
|
162
162
|
|
163
163
|
@classmethod
|
164
|
-
def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config:
|
164
|
+
def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
|
165
165
|
return model.multi_modal_projector
|
166
166
|
|
167
167
|
@classmethod
|
168
|
-
def
|
168
|
+
def _update_rbln_config(
|
169
169
|
cls,
|
170
170
|
preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
|
171
|
+
model: Optional["PreTrainedModel"] = None,
|
171
172
|
model_config: Optional["PretrainedConfig"] = None,
|
172
|
-
|
173
|
-
) ->
|
174
|
-
vision_feature_select_strategy = rbln_kwargs.get("vision_feature_select_strategy", None)
|
175
|
-
|
176
|
-
# 1. Multi-modal projection layer
|
177
|
-
batch_size = rbln_kwargs.get("rbln_batch_size", None)
|
178
|
-
if batch_size is None:
|
179
|
-
batch_size = 1
|
180
|
-
|
173
|
+
rbln_config: Optional[RBLNModelConfig] = None,
|
174
|
+
) -> RBLNModelConfig:
|
181
175
|
feature_size = model_config.vision_config.hidden_size
|
182
176
|
|
183
|
-
# See forward function to see more details.
|
184
|
-
vision_feature_select_strategy = (
|
185
|
-
vision_feature_select_strategy
|
186
|
-
if vision_feature_select_strategy is not None
|
187
|
-
else model_config.vision_feature_select_strategy
|
188
|
-
)
|
189
|
-
|
190
177
|
# Calculating `num_positions` : See CLIPVisionEmbeddings of transformers for more details.
|
191
178
|
num_positions = (model_config.vision_config.image_size // model_config.vision_config.patch_size) ** 2 + 1
|
192
|
-
if vision_feature_select_strategy == "default":
|
179
|
+
if model_config.vision_feature_select_strategy == "default":
|
193
180
|
selected_image_feature_dim = num_positions - 1
|
194
181
|
else:
|
195
182
|
selected_image_feature_dim = num_positions
|
196
183
|
|
197
|
-
input_info = [
|
184
|
+
input_info = [
|
185
|
+
("image_features", [rbln_config.batch_size, selected_image_feature_dim, feature_size], "float32")
|
186
|
+
]
|
198
187
|
rbln_compile_config = RBLNCompileConfig(input_info=input_info)
|
199
|
-
rbln_config
|
188
|
+
rbln_config.set_compile_cfgs([rbln_compile_config])
|
200
189
|
return rbln_config
|
201
190
|
|
202
191
|
def prepare_inputs_for_generation(
|
@@ -20,4 +20,5 @@ this_path = os.path.abspath(__file__)
|
|
20
20
|
local_dir = "/" + os.path.join(*this_path.split("/")[:-1]) + "/hf_hub_cached"
|
21
21
|
environ["LOCAL_CACHE_ROOT_CUSTOM_CODE_MIDM"] = local_dir
|
22
22
|
|
23
|
+
from .configuration_midm import RBLNMidmLMHeadModelConfig
|
23
24
|
from .modeling_midm import RBLNMidmLMHeadModel
|
@@ -0,0 +1,19 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
|
16
|
+
|
17
|
+
|
18
|
+
class RBLNMidmLMHeadModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
    """Configuration for the RBLN Midm LM-head model.

    Declares no options of its own. Every setting is inherited unchanged from
    RBLNDecoderOnlyModelForCausalLMConfig; this class only gives it a
    model-specific name.
    """
|
@@ -0,0 +1,19 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
|
16
|
+
|
17
|
+
|
18
|
+
class RBLNMistralForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
    """Configuration for the RBLN Mistral causal-LM model.

    Declares no options of its own. Every setting is inherited unchanged from
    RBLNDecoderOnlyModelForCausalLMConfig; this class only gives it a
    model-specific name.
    """
|
@@ -0,0 +1,19 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
|
16
|
+
|
17
|
+
|
18
|
+
class RBLNPhiForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
    """Configuration for the RBLN Phi causal-LM model.

    Declares no options of its own. Every setting is inherited unchanged from
    RBLNDecoderOnlyModelForCausalLMConfig; this class only gives it a
    model-specific name.
    """
|
@@ -0,0 +1,19 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
|
16
|
+
|
17
|
+
|
18
|
+
class RBLNQwen2ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
    """Configuration for the RBLN Qwen2 causal-LM model.

    Declares no options of its own. Every setting is inherited unchanged from
    RBLNDecoderOnlyModelForCausalLMConfig; this class only gives it a
    model-specific name.
    """
|
@@ -0,0 +1,66 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Optional
|
16
|
+
|
17
|
+
import rebel
|
18
|
+
|
19
|
+
from ....configuration_utils import RBLNModelConfig
|
20
|
+
from ....utils.logging import get_logger
|
21
|
+
|
22
|
+
|
23
|
+
logger = get_logger()
|
24
|
+
|
25
|
+
|
26
|
+
class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
    """RBLN configuration for encoder-decoder (seq2seq) language models."""

    def __init__(
        self,
        batch_size: Optional[int] = None,
        enc_max_seq_len: Optional[int] = None,
        dec_max_seq_len: Optional[int] = None,
        use_attention_mask: Optional[bool] = None,
        pad_token_id: Optional[int] = None,
        **kwargs,
    ):
        """
        Args:
            batch_size (Optional[int]): The batch size for inference. Defaults to 1 when None.
            enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
                Stored as given; None leaves it unset at this stage.
            dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
                Stored as given; None leaves it unset at this stage.
            use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
                On RBLN-CA02 devices it is always forced to True, with a warning if it
                was explicitly False. On other devices, None becomes False.
            pad_token_id (Optional[int]): The ID of the padding token in the vocabulary.
            **kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        super().__init__(**kwargs)
        # Apply the default only when the value is missing. Using `batch_size or 1`
        # would silently turn an explicit 0 into 1 and bypass the check below.
        self.batch_size = 1 if batch_size is None else batch_size
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(self.batch_size, bool) or not isinstance(self.batch_size, int) or self.batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        self.enc_max_seq_len = enc_max_seq_len
        self.dec_max_seq_len = dec_max_seq_len

        self.use_attention_mask = use_attention_mask
        # NOTE(review): `self.npu` is presumably set by the parent config from kwargs
        # (confirm). When it is falsy, ask the rebel runtime for the local NPU name.
        npu = self.npu or rebel.get_npu_name()
        if npu == "RBLN-CA02":
            if self.use_attention_mask is False:
                logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
            self.use_attention_mask = True
        else:
            # Normalize an unspecified (None) value to False.
            self.use_attention_mask = self.use_attention_mask or False

        self.pad_token_id = pad_token_id
|
@@ -22,10 +22,11 @@ from rebel.compile_context import CompileContext
|
|
22
22
|
from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
|
23
23
|
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
24
24
|
|
25
|
+
from ....configuration_utils import RBLNCompileConfig
|
25
26
|
from ....modeling import RBLNModel
|
26
|
-
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
27
27
|
from ....utils.logging import get_logger
|
28
28
|
from ....utils.runtime_utils import RBLNPytorchRuntime
|
29
|
+
from .configuration_seq2seq2 import RBLNModelForSeq2SeqLMConfig
|
29
30
|
|
30
31
|
|
31
32
|
logger = get_logger(__name__)
|
@@ -118,9 +119,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
118
119
|
support_causal_attn = None
|
119
120
|
|
120
121
|
def __post_init__(self, **kwargs):
|
121
|
-
batch_size = self.rbln_config.
|
122
|
-
dec_max_seq_len = self.rbln_config.
|
123
|
-
self.use_attention_mask = self.rbln_config.
|
122
|
+
batch_size = self.rbln_config.batch_size
|
123
|
+
dec_max_seq_len = self.rbln_config.dec_max_seq_len
|
124
|
+
self.use_attention_mask = self.rbln_config.use_attention_mask
|
124
125
|
|
125
126
|
self.encoder = RBLNRuntimeEncoder(
|
126
127
|
runtime=self.model[0],
|
@@ -136,7 +137,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
136
137
|
|
137
138
|
@classmethod
|
138
139
|
@torch.inference_mode()
|
139
|
-
def get_compiled_model(cls, model: PreTrainedModel, rbln_config:
|
140
|
+
def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
|
140
141
|
wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
|
141
142
|
|
142
143
|
enc_compile_config = rbln_config.compile_cfgs[0]
|
@@ -177,26 +178,15 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
177
178
|
return {"encoder": compiled_encoder, "decoder": compiled_decoder}
|
178
179
|
|
179
180
|
@classmethod
|
180
|
-
def
|
181
|
+
def _update_rbln_config(
|
181
182
|
cls,
|
182
183
|
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
|
190
|
-
|
191
|
-
if cls.support_causal_attn:
|
192
|
-
rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
|
193
|
-
if rbln_use_attention_mask is None:
|
194
|
-
rbln_use_attention_mask = False
|
195
|
-
rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
|
196
|
-
if rbln_npu == "RBLN-CA02":
|
197
|
-
rbln_use_attention_mask = True
|
198
|
-
else:
|
199
|
-
rbln_use_attention_mask = True
|
184
|
+
model: Optional["PreTrainedModel"] = None,
|
185
|
+
model_config: Optional["PretrainedConfig"] = None,
|
186
|
+
rbln_config: Optional[RBLNModelForSeq2SeqLMConfig] = None,
|
187
|
+
) -> RBLNModelForSeq2SeqLMConfig:
|
188
|
+
if not cls.support_causal_attn:
|
189
|
+
rbln_config.use_attention_mask = True
|
200
190
|
|
201
191
|
n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
|
202
192
|
n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
|
@@ -210,43 +200,44 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
210
200
|
model_config, "max_position_embeddings", None
|
211
201
|
)
|
212
202
|
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
if max_position_embeddings is not None and
|
231
|
-
raise ValueError("`
|
232
|
-
|
233
|
-
if
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
203
|
+
pad_token_id = getattr(model_config, "pad_token_id", None)
|
204
|
+
pad_token_id = pad_token_id or getattr(model_config, "bos_token_id", None)
|
205
|
+
pad_token_id = pad_token_id or getattr(model_config, "eos_token_id", None)
|
206
|
+
pad_token_id = pad_token_id or -1
|
207
|
+
rbln_config.pad_token_id = pad_token_id
|
208
|
+
|
209
|
+
if rbln_config.enc_max_seq_len is None:
|
210
|
+
enc_max_seq_len = max_position_embeddings
|
211
|
+
for tokenizer in preprocessors:
|
212
|
+
if hasattr(tokenizer, "model_max_length"):
|
213
|
+
enc_max_seq_len = enc_max_seq_len or tokenizer.model_max_length
|
214
|
+
break
|
215
|
+
|
216
|
+
if enc_max_seq_len is None:
|
217
|
+
raise ValueError("`enc_max_seq_len` should be specified!")
|
218
|
+
rbln_config.enc_max_seq_len = enc_max_seq_len
|
219
|
+
|
220
|
+
if max_position_embeddings is not None and rbln_config.enc_max_seq_len > max_position_embeddings:
|
221
|
+
raise ValueError("`enc_max_seq_len` should be less or equal than max_position_embeddings!")
|
222
|
+
|
223
|
+
if rbln_config.dec_max_seq_len is None:
|
224
|
+
dec_max_seq_len = max_position_embeddings
|
225
|
+
for tokenizer in preprocessors:
|
226
|
+
if hasattr(tokenizer, "model_max_length"):
|
227
|
+
dec_max_seq_len = dec_max_seq_len or tokenizer.model_max_length
|
228
|
+
break
|
229
|
+
|
230
|
+
if dec_max_seq_len is None:
|
231
|
+
raise ValueError("`dec_max_seq_len` should be specified!")
|
232
|
+
rbln_config.dec_max_seq_len = dec_max_seq_len
|
233
|
+
|
234
|
+
if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
|
235
|
+
raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")
|
245
236
|
|
246
237
|
# model input info
|
247
238
|
enc_input_info = [
|
248
|
-
("input_ids", [1,
|
249
|
-
("attention_mask", [1,
|
239
|
+
("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
|
240
|
+
("attention_mask", [1, rbln_config.enc_max_seq_len], "float32"),
|
250
241
|
("block_tables", [1], "int16"),
|
251
242
|
]
|
252
243
|
enc_input_info.extend(
|
@@ -254,9 +245,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
254
245
|
(
|
255
246
|
f"cross_key_value_states_{i}",
|
256
247
|
[
|
257
|
-
|
248
|
+
rbln_config.batch_size,
|
258
249
|
n_head,
|
259
|
-
|
250
|
+
rbln_config.enc_max_seq_len,
|
260
251
|
d_kv,
|
261
252
|
],
|
262
253
|
"float32",
|
@@ -266,23 +257,23 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
266
257
|
)
|
267
258
|
|
268
259
|
dec_input_info = [
|
269
|
-
("input_ids", [
|
270
|
-
("encoder_attention_mask", [
|
260
|
+
("input_ids", [rbln_config.batch_size, 1], "int64"),
|
261
|
+
("encoder_attention_mask", [rbln_config.batch_size, rbln_config.enc_max_seq_len], "float32"),
|
271
262
|
(
|
272
263
|
"cache_position",
|
273
|
-
[
|
264
|
+
[rbln_config.batch_size, 1],
|
274
265
|
"int32",
|
275
266
|
),
|
276
|
-
("block_tables", [
|
267
|
+
("block_tables", [rbln_config.batch_size, 1], "int16"),
|
277
268
|
]
|
278
269
|
dec_input_info.extend(
|
279
270
|
[
|
280
271
|
(
|
281
272
|
f"cross_key_value_states_{i}",
|
282
273
|
[
|
283
|
-
|
274
|
+
rbln_config.batch_size,
|
284
275
|
n_head,
|
285
|
-
|
276
|
+
rbln_config.enc_max_seq_len,
|
286
277
|
d_kv,
|
287
278
|
],
|
288
279
|
"float32",
|
@@ -295,9 +286,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
295
286
|
(
|
296
287
|
f"self_key_value_states_{i}",
|
297
288
|
[
|
298
|
-
|
289
|
+
rbln_config.batch_size,
|
299
290
|
n_head,
|
300
|
-
|
291
|
+
rbln_config.dec_max_seq_len,
|
301
292
|
d_kv,
|
302
293
|
],
|
303
294
|
"float32",
|
@@ -306,46 +297,38 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
306
297
|
]
|
307
298
|
)
|
308
299
|
|
309
|
-
if
|
310
|
-
dec_input_info.insert(
|
300
|
+
if rbln_config.use_attention_mask:
|
301
|
+
dec_input_info.insert(
|
302
|
+
1, ("attention_mask", [rbln_config.batch_size, rbln_config.dec_max_seq_len], "float32")
|
303
|
+
)
|
311
304
|
|
312
305
|
enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
|
313
306
|
dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
|
314
307
|
|
315
|
-
rbln_config
|
316
|
-
rbln_cls=cls.__name__,
|
317
|
-
compile_cfgs=[enc_compile_config, dec_compile_config],
|
318
|
-
rbln_kwargs=rbln_kwargs,
|
319
|
-
)
|
320
|
-
|
321
|
-
rbln_config.model_cfg.update(
|
322
|
-
{
|
323
|
-
"enc_max_seq_len": rbln_enc_max_seq_len,
|
324
|
-
"dec_max_seq_len": rbln_dec_max_seq_len,
|
325
|
-
"batch_size": rbln_batch_size,
|
326
|
-
"pad_token_id": rbln_pad_token_id,
|
327
|
-
"use_attention_mask": rbln_use_attention_mask,
|
328
|
-
}
|
329
|
-
)
|
330
|
-
|
308
|
+
rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
|
331
309
|
return rbln_config
|
332
310
|
|
333
311
|
@classmethod
|
334
312
|
def _create_runtimes(
|
335
313
|
cls,
|
336
314
|
compiled_models: List[rebel.RBLNCompiledModel],
|
337
|
-
|
338
|
-
activate_profiler: Optional[bool] = None,
|
315
|
+
rbln_config: RBLNModelForSeq2SeqLMConfig,
|
339
316
|
) -> List[rebel.Runtime]:
|
340
|
-
if any(model_name not in
|
317
|
+
if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
|
341
318
|
cls._raise_missing_compiled_file_error(["encoder", "decoder"])
|
342
319
|
|
343
320
|
return [
|
344
|
-
|
345
|
-
|
321
|
+
rebel.Runtime(
|
322
|
+
compiled_models[0],
|
323
|
+
tensor_type="pt",
|
324
|
+
device=rbln_config.device_map["encoder"],
|
325
|
+
activate_profiler=rbln_config.activate_profiler,
|
346
326
|
),
|
347
|
-
|
348
|
-
|
327
|
+
rebel.Runtime(
|
328
|
+
compiled_models[1],
|
329
|
+
tensor_type="pt",
|
330
|
+
device=rbln_config.device_map["decoder"],
|
331
|
+
activate_profiler=rbln_config.activate_profiler,
|
349
332
|
),
|
350
333
|
]
|
351
334
|
|
@@ -367,7 +350,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
367
350
|
):
|
368
351
|
cur_seq_len = input_ids.shape[-1]
|
369
352
|
cache_position = cur_seq_len - 1
|
370
|
-
max_seq_len = self.rbln_config.
|
353
|
+
max_seq_len = self.rbln_config.dec_max_seq_len
|
371
354
|
decoder_batch_size = input_ids.shape[0]
|
372
355
|
input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
|
373
356
|
decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
|
@@ -387,7 +370,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
387
370
|
**kwargs,
|
388
371
|
) -> Tuple[torch.FloatTensor]:
|
389
372
|
# common decoder
|
390
|
-
cache_position = torch.full((self.rbln_config.
|
373
|
+
cache_position = torch.full((self.rbln_config.batch_size, 1), cache_position, dtype=torch.int32)
|
391
374
|
logits = self.decoder(decoder_input_ids=decoder_input_ids, cache_position=cache_position, **kwargs).logits
|
392
375
|
|
393
376
|
return Seq2SeqLMOutput(
|
@@ -421,11 +404,11 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
421
404
|
batch_size, input_len = inputs_tensor.shape
|
422
405
|
inputs_tensor = torch.nn.functional.pad(
|
423
406
|
inputs_tensor,
|
424
|
-
(0, self.rbln_config.
|
425
|
-
value=self.rbln_config.
|
407
|
+
(0, self.rbln_config.enc_max_seq_len - input_len),
|
408
|
+
value=self.rbln_config.pad_token_id,
|
426
409
|
)
|
427
410
|
model_kwargs["attention_mask"] = torch.nn.functional.pad(
|
428
|
-
model_kwargs["attention_mask"], (0, self.rbln_config.
|
411
|
+
model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
|
429
412
|
)
|
430
413
|
|
431
414
|
# 3. make sure that encoder returns `ModelOutput`
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
|
16
|
+
from ..seq2seq import RBLNModelForSeq2SeqLMConfig
|
17
|
+
|
18
|
+
|
19
|
+
class RBLNT5EncoderModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
    """Configuration for the RBLN T5 encoder-only model.

    Declares no options of its own. Every setting is inherited unchanged from
    RBLNTransformerEncoderForFeatureExtractionConfig.
    """
|
21
|
+
|
22
|
+
|
23
|
+
class RBLNT5ForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
    """Configuration for the RBLN T5 conditional-generation model.

    Declares no options of its own. Every setting is inherited unchanged from
    RBLNModelForSeq2SeqLMConfig.
    """
|