optimum-rbln 0.7.4a3__py3-none-any.whl → 0.7.4a5__py3-none-any.whl

This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (96)
  1. optimum/rbln/__init__.py +156 -36
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +85 -80
  31. optimum/rbln/transformers/__init__.py +79 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +96 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  73. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  74. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  75. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  76. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  77. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  78. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  79. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  80. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  81. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  82. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  83. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  84. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  85. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  86. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  87. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  88. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  89. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  90. optimum/rbln/utils/submodule.py +26 -43
  91. {optimum_rbln-0.7.4a3.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
  92. optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
  93. optimum/rbln/modeling_config.py +0 -310
  94. optimum_rbln-0.7.4a3.dist-info/RECORD +0 -126
  95. {optimum_rbln-0.7.4a3.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
  96. {optimum_rbln-0.7.4a3.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/llava_next/__init__.py
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .configuration_llava_next import RBLNLlavaNextForConditionalGenerationConfig
 from .modeling_llava_next import RBLNLlavaNextForConditionalGeneration
optimum/rbln/transformers/models/llava_next/configuration_llava_next.py
@@ -0,0 +1,46 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNLlavaNextForConditionalGenerationConfig(RBLNModelConfig):
+    submodules = ["vision_tower", "language_model"]
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        vision_tower: Optional[RBLNModelConfig] = None,
+        language_model: Optional[RBLNModelConfig] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+            language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.vision_tower = vision_tower
+        self.language_model = language_model
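For orientation, a minimal usage sketch of the new configuration class. The top-level `optimum.rbln` import path is an assumption based on the expanded `optimum/rbln/__init__.py` exports listed above; the values are illustrative.

from optimum.rbln import RBLNLlavaNextForConditionalGenerationConfig

# batch_size defaults to 1; vision_tower and language_model accept nested
# RBLNModelConfig instances for the two submodules declared on the class.
cfg = RBLNLlavaNextForConditionalGenerationConfig(batch_size=1)
assert cfg.batch_size == 1
assert cfg.submodules == ["vision_tower", "language_model"]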
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py
@@ -26,8 +26,8 @@ from transformers import (
 )
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 
+from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput
 
@@ -134,7 +134,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         model: "LlavaNextForConditionalGeneration",
         save_dir_path: Path,
         subfolder: str,
-        rbln_config: RBLNConfig,
+        rbln_config: RBLNModelConfig,
     ):
         """
         If you are unavoidably running on a CPU rather than an RBLN device,
@@ -161,42 +161,31 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         return self.language_model.get_input_embeddings()
 
     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
+    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
         return model.multi_modal_projector
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model: Optional["PreTrainedModel"] = None,
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_kwargs={},
-    ) -> RBLNConfig:
-        vision_feature_select_strategy = rbln_kwargs.get("vision_feature_select_strategy", None)
-
-        # 1. Multi-modal projection layer
-        batch_size = rbln_kwargs.get("rbln_batch_size", None)
-        if batch_size is None:
-            batch_size = 1
-
+        rbln_config: Optional[RBLNModelConfig] = None,
+    ) -> RBLNModelConfig:
         feature_size = model_config.vision_config.hidden_size
 
-        # See forward function to see more details.
-        vision_feature_select_strategy = (
-            vision_feature_select_strategy
-            if vision_feature_select_strategy is not None
-            else model_config.vision_feature_select_strategy
-        )
-
         # Calculating `num_positions` : See CLIPVisionEmbeddings of transformers for more details.
         num_positions = (model_config.vision_config.image_size // model_config.vision_config.patch_size) ** 2 + 1
-        if vision_feature_select_strategy == "default":
+        if model_config.vision_feature_select_strategy == "default":
             selected_image_feature_dim = num_positions - 1
         else:
             selected_image_feature_dim = num_positions
 
-        input_info = [("image_features", [batch_size, selected_image_feature_dim, feature_size], "float32")]
+        input_info = [
+            ("image_features", [rbln_config.batch_size, selected_image_feature_dim, feature_size], "float32")
+        ]
         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-        rbln_config = RBLNConfig(rbln_cls=cls.__name__, compile_cfgs=[rbln_compile_config], rbln_kwargs=rbln_kwargs)
+        rbln_config.set_compile_cfgs([rbln_compile_config])
         return rbln_config
 
     def prepare_inputs_for_generation(
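As a concrete check of the shape arithmetic in `_update_rbln_config` above, take a hypothetical CLIP-style vision tower with `image_size=336` and `patch_size=14` (values assumed for illustration, not fixed by this diff):

image_size, patch_size = 336, 14
num_positions = (image_size // patch_size) ** 2 + 1  # 24 * 24 patches + 1 CLS token = 577
# The "default" strategy drops the CLS position; any other strategy keeps it.
selected_image_feature_dim = num_positions - 1       # 576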
optimum/rbln/transformers/models/midm/__init__.py
@@ -20,4 +20,5 @@ this_path = os.path.abspath(__file__)
 local_dir = "/" + os.path.join(*this_path.split("/")[:-1]) + "/hf_hub_cached"
 environ["LOCAL_CACHE_ROOT_CUSTOM_CODE_MIDM"] = local_dir
 
+from .configuration_midm import RBLNMidmLMHeadModelConfig
 from .modeling_midm import RBLNMidmLMHeadModel
optimum/rbln/transformers/models/midm/configuration_midm.py
@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNMidmLMHeadModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    pass
optimum/rbln/transformers/models/mistral/__init__.py
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .configuration_mistral import RBLNMistralForCausalLMConfig
 from .modeling_mistral import RBLNMistralForCausalLM
optimum/rbln/transformers/models/mistral/configuration_mistral.py
@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNMistralForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    pass
optimum/rbln/transformers/models/phi/__init__.py
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .configuration_phi import RBLNPhiForCausalLMConfig
 from .modeling_phi import RBLNPhiForCausalLM
optimum/rbln/transformers/models/phi/configuration_phi.py
@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNPhiForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    pass
optimum/rbln/transformers/models/qwen2/__init__.py
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .configuration_qwen2 import RBLNQwen2ForCausalLMConfig
 from .modeling_qwen2 import RBLNQwen2ForCausalLM
optimum/rbln/transformers/models/qwen2/configuration_qwen2.py
@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNQwen2ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    pass
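The Midm, Mistral, Phi, and Qwen2 configuration files above all follow the same pattern: an empty subclass that gives each architecture a dedicated config type while inheriting the full constructor of `RBLNDecoderOnlyModelForCausalLMConfig`. A minimal sketch, assuming the parent accepts `batch_size` (its actual signature lives in `configuration_decoderonly.py`, which is not shown in this excerpt):

from optimum.rbln.transformers.models.qwen2 import RBLNQwen2ForCausalLMConfig

# Behaves exactly like its parent; the subclass only pins down the model type.
# batch_size here is an assumed parent parameter, not confirmed by this diff.
cfg = RBLNQwen2ForCausalLMConfig(batch_size=1)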
optimum/rbln/transformers/models/seq2seq/__init__.py
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .configuration_seq2seq2 import RBLNModelForSeq2SeqLMConfig
 from .modeling_seq2seq import RBLNModelForSeq2SeqLM
optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py
@@ -0,0 +1,66 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import rebel
+
+from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger()
+
+
+class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        enc_max_seq_len: Optional[int] = None,
+        dec_max_seq_len: Optional[int] = None,
+        use_attention_mask: Optional[bool] = None,
+        pad_token_id: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
+            dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
+            use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
+                This is automatically set to True for RBLN-CA02 devices.
+            pad_token_id (Optional[int]): The ID of the padding token in the vocabulary.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.enc_max_seq_len = enc_max_seq_len
+        self.dec_max_seq_len = dec_max_seq_len
+
+        self.use_attention_mask = use_attention_mask
+        npu = self.npu or rebel.get_npu_name()
+        if npu == "RBLN-CA02":
+            if self.use_attention_mask is False:
+                logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
+            self.use_attention_mask = True
+        else:
+            self.use_attention_mask = self.use_attention_mask or False
+
+        self.pad_token_id = pad_token_id
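A minimal usage sketch of the class above. Note that the constructor queries the NPU (`self.npu or rebel.get_npu_name()`), so this assumes a host with the rebel runtime available; the values are illustrative.

from optimum.rbln.transformers.models.seq2seq import RBLNModelForSeq2SeqLMConfig

cfg = RBLNModelForSeq2SeqLMConfig(batch_size=2, enc_max_seq_len=512, dec_max_seq_len=256)
# On RBLN-CA02, use_attention_mask is forced to True (with a warning if it was
# explicitly passed as False); on other devices it defaults to False.
print(cfg.use_attention_mask)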
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -22,10 +22,11 @@ from rebel.compile_context import CompileContext
 from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from .configuration_seq2seq2 import RBLNModelForSeq2SeqLMConfig
 
 
 logger = get_logger(__name__)
@@ -118,9 +119,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     support_causal_attn = None
 
     def __post_init__(self, **kwargs):
-        batch_size = self.rbln_config.model_cfg["batch_size"]
-        dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
-        self.use_attention_mask = self.rbln_config.model_cfg.get("use_attention_mask", None)
+        batch_size = self.rbln_config.batch_size
+        dec_max_seq_len = self.rbln_config.dec_max_seq_len
+        self.use_attention_mask = self.rbln_config.use_attention_mask
 
         self.encoder = RBLNRuntimeEncoder(
             runtime=self.model[0],
@@ -136,7 +137,7 @@
 
     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNConfig):
+    def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
         enc_compile_config = rbln_config.compile_cfgs[0]
@@ -177,26 +178,15 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         return {"encoder": compiled_encoder, "decoder": compiled_decoder}
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
-        rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-        if cls.support_causal_attn:
-            rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
-            if rbln_use_attention_mask is None:
-                rbln_use_attention_mask = False
-                rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
-                if rbln_npu == "RBLN-CA02":
-                    rbln_use_attention_mask = True
-        else:
-            rbln_use_attention_mask = True
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNModelForSeq2SeqLMConfig] = None,
+    ) -> RBLNModelForSeq2SeqLMConfig:
+        if not cls.support_causal_attn:
+            rbln_config.use_attention_mask = True
 
         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
@@ -210,43 +200,44 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
             model_config, "max_position_embeddings", None
         )
 
-        rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
-        if rbln_pad_token_id is None:
-            rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
-            if rbln_pad_token_id is None:
-                rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
-                if rbln_pad_token_id is None:
-                    rbln_pad_token_id = -1
-
-        if rbln_enc_max_seq_len is None:
-            rbln_enc_max_seq_len = max_position_embeddings
-            if rbln_enc_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_enc_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_enc_max_seq_len is None:
-                    raise ValueError("`rbln_enc_max_seq_len` should be specified!")
-        if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        if rbln_dec_max_seq_len is None:
-            rbln_dec_max_seq_len = max_position_embeddings
-            if rbln_dec_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_dec_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_dec_max_seq_len is None:
-                    raise ValueError("`rbln_dec_max_seq_len` should be specified!")
-
-        if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
+        pad_token_id = getattr(model_config, "pad_token_id", None)
+        pad_token_id = pad_token_id or getattr(model_config, "bos_token_id", None)
+        pad_token_id = pad_token_id or getattr(model_config, "eos_token_id", None)
+        pad_token_id = pad_token_id or -1
+        rbln_config.pad_token_id = pad_token_id
+
+        if rbln_config.enc_max_seq_len is None:
+            enc_max_seq_len = max_position_embeddings
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_max_length"):
+                    enc_max_seq_len = enc_max_seq_len or tokenizer.model_max_length
+                    break
+
+            if enc_max_seq_len is None:
+                raise ValueError("`enc_max_seq_len` should be specified!")
+            rbln_config.enc_max_seq_len = enc_max_seq_len
+
+        if max_position_embeddings is not None and rbln_config.enc_max_seq_len > max_position_embeddings:
+            raise ValueError("`enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_config.dec_max_seq_len is None:
+            dec_max_seq_len = max_position_embeddings
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_max_length"):
+                    dec_max_seq_len = dec_max_seq_len or tokenizer.model_max_length
+                    break
+
+            if dec_max_seq_len is None:
+                raise ValueError("`dec_max_seq_len` should be specified!")
+            rbln_config.dec_max_seq_len = dec_max_seq_len
+
+        if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
+            raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")
 
         # model input info
         enc_input_info = [
-            ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
-            ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
+            ("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
+            ("attention_mask", [1, rbln_config.enc_max_seq_len], "float32"),
             ("block_tables", [1], "int16"),
         ]
         enc_input_info.extend(
@@ -254,9 +245,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 (
                     f"cross_key_value_states_{i}",
                     [
-                        rbln_batch_size,
+                        rbln_config.batch_size,
                         n_head,
-                        rbln_enc_max_seq_len,
+                        rbln_config.enc_max_seq_len,
                         d_kv,
                     ],
                     "float32",
@@ -266,23 +257,23 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         )
 
         dec_input_info = [
-            ("input_ids", [rbln_batch_size, 1], "int64"),
-            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
+            ("input_ids", [rbln_config.batch_size, 1], "int64"),
+            ("encoder_attention_mask", [rbln_config.batch_size, rbln_config.enc_max_seq_len], "float32"),
             (
                 "cache_position",
-                [rbln_batch_size, 1],
+                [rbln_config.batch_size, 1],
                 "int32",
             ),
-            ("block_tables", [rbln_batch_size, 1], "int16"),
+            ("block_tables", [rbln_config.batch_size, 1], "int16"),
         ]
         dec_input_info.extend(
             [
                 (
                     f"cross_key_value_states_{i}",
                     [
-                        rbln_batch_size,
+                        rbln_config.batch_size,
                         n_head,
-                        rbln_enc_max_seq_len,
+                        rbln_config.enc_max_seq_len,
                         d_kv,
                     ],
                     "float32",
@@ -295,9 +286,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 (
                     f"self_key_value_states_{i}",
                     [
-                        rbln_batch_size,
+                        rbln_config.batch_size,
                         n_head,
-                        rbln_dec_max_seq_len,
+                        rbln_config.dec_max_seq_len,
                         d_kv,
                     ],
                     "float32",
@@ -306,46 +297,38 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
             ]
         )
 
-        if rbln_use_attention_mask:
-            dec_input_info.insert(1, ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
+        if rbln_config.use_attention_mask:
+            dec_input_info.insert(
+                1, ("attention_mask", [rbln_config.batch_size, rbln_config.dec_max_seq_len], "float32")
+            )
 
         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[enc_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "enc_max_seq_len": rbln_enc_max_seq_len,
-                "dec_max_seq_len": rbln_dec_max_seq_len,
-                "batch_size": rbln_batch_size,
-                "pad_token_id": rbln_pad_token_id,
-                "use_attention_mask": rbln_use_attention_mask,
-            }
-        )
-
+        rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
         return rbln_config
 
     @classmethod
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_device_map: Dict[str, int],
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNModelForSeq2SeqLMConfig,
     ) -> List[rebel.Runtime]:
-        if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
+        if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
            cls._raise_missing_compiled_file_error(["encoder", "decoder"])
 
         return [
-            compiled_models[0].create_runtime(
-                tensor_type="pt", device=rbln_device_map["encoder"], activate_profiler=activate_profiler
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["encoder"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
-            compiled_models[1].create_runtime(
-                tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+            rebel.Runtime(
+                compiled_models[1],
+                tensor_type="pt",
+                device=rbln_config.device_map["decoder"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
         ]
 
@@ -367,7 +350,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     ):
         cur_seq_len = input_ids.shape[-1]
         cache_position = cur_seq_len - 1
-        max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+        max_seq_len = self.rbln_config.dec_max_seq_len
         decoder_batch_size = input_ids.shape[0]
         input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
         decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
@@ -387,7 +370,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # common decoder
-        cache_position = torch.full((self.rbln_config.model_cfg["batch_size"], 1), cache_position, dtype=torch.int32)
+        cache_position = torch.full((self.rbln_config.batch_size, 1), cache_position, dtype=torch.int32)
         logits = self.decoder(decoder_input_ids=decoder_input_ids, cache_position=cache_position, **kwargs).logits
 
         return Seq2SeqLMOutput(
@@ -421,11 +404,11 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         batch_size, input_len = inputs_tensor.shape
         inputs_tensor = torch.nn.functional.pad(
             inputs_tensor,
-            (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len),
-            value=self.rbln_config.model_cfg["pad_token_id"],
+            (0, self.rbln_config.enc_max_seq_len - input_len),
+            value=self.rbln_config.pad_token_id,
         )
         model_kwargs["attention_mask"] = torch.nn.functional.pad(
-            model_kwargs["attention_mask"], (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len)
+            model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
         )
 
         # 3. make sure that encoder returns `ModelOutput`
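One behavioral nuance in the rewritten pad-token fallback (the `@@ -210,43 +200,44 @@` hunk above): the old code tested `is None` at each step, while the new code chains `or`, so a legitimate token id of 0 now falls through to the next candidate. A small illustration with hypothetical ids:

# Old behavior: only None triggers the fallback.
pad = 0
pad = pad if pad is not None else 2   # stays 0

# New behavior: `or` treats 0 as falsy, so id 0 falls through.
pad = 0
pad = pad or 2                        # becomes 2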
optimum/rbln/transformers/models/t5/__init__.py
@@ -13,4 +13,5 @@
 # limitations under the License.
 
 from ....ops import paged_add_softmax_attn_decode
+from .configuration_t5 import RBLNT5EncoderModelConfig, RBLNT5ForConditionalGenerationConfig
 from .modeling_t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
optimum/rbln/transformers/models/t5/configuration_t5.py
@@ -0,0 +1,24 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
+from ..seq2seq import RBLNModelForSeq2SeqLMConfig
+
+
+class RBLNT5EncoderModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+    pass
+
+
+class RBLNT5ForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
+    pass
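Both T5 configs above are empty subclasses, so they take the constructor arguments of their parents. A minimal sketch for the seq2seq variant (values illustrative; as with its parent, constructing it assumes the rebel runtime is available, and the feature-extraction parent's signature lives in `configuration_generic.py`, not shown here):

from optimum.rbln.transformers.models.t5 import RBLNT5ForConditionalGenerationConfig

# Accepts the RBLNModelForSeq2SeqLMConfig arguments shown earlier in this diff.
cfg = RBLNT5ForConditionalGenerationConfig(enc_max_seq_len=512, dec_max_seq_len=256)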