optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (101)
  1. optimum/rbln/__init__.py +164 -36
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +107 -78
  31. optimum/rbln/transformers/__init__.py +87 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +108 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +115 -84
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +282 -216
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  73. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  74. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  75. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  76. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  77. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  78. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  79. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  80. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  81. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  82. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  83. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  84. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  85. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  86. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  87. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  88. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  89. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  90. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  91. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  92. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  93. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  94. optimum/rbln/utils/runtime_utils.py +33 -2
  95. optimum/rbln/utils/submodule.py +26 -43
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/METADATA +1 -1
  97. optimum_rbln-0.7.4a6.dist-info/RECORD +166 -0
  98. optimum/rbln/modeling_config.py +0 -310
  99. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  100. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/WHEEL +0 -0
  101. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py
@@ -0,0 +1,214 @@
+ import math
+ from typing import Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from ..decoderonly.decoderonly_architecture import (
+     DecoderOnlyWrapper,
+     apply_rotary_pos_emb,
+ )
+
+
+ class Qwen2_5_VisionTransformerWrapper(nn.Module):
+     def __init__(self, model: torch.nn.Module):
+         super().__init__()
+         self._original_mod = model
+         self.fullatt_block_indexes = model.fullatt_block_indexes
+         self.merger = model.merger
+         window_seq_len = (model.window_size // model.patch_size) ** 2
+         self.blocks = self.wrap_vision_blocks(model.blocks, window_seq_len)
+
+     def wrap_vision_blocks(self, blocks: torch.nn.ModuleList, window_seq_len: int):
+         wrapped_blocks = []
+         for i, block in enumerate(blocks):
+             is_full_attn = True if i in self.fullatt_block_indexes else False
+             wrapped_blocks.append(Qwen2_5_VLVisionBlock(block, is_full_attn, window_seq_len))
+         return nn.ModuleList(wrapped_blocks)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         full_attn_masks: torch.Tensor,
+         window_attn_masks: torch.Tensor,
+         cos: torch.Tensor,
+         sin: torch.Tensor,
+     ):
+         full_attn_masks = (1 - full_attn_masks) * torch.finfo(torch.float32).min
+         window_attn_masks = (1 - window_attn_masks) * torch.finfo(torch.float32).min
+
+         for i, block in enumerate(self.blocks):
+             attn_masks = full_attn_masks if i in self.fullatt_block_indexes else window_attn_masks
+             hidden_states = block(hidden_states, attn_masks, [cos, sin])
+
+         hidden_states = self.merger(hidden_states)
+
+         return hidden_states
+
+
+ class Qwen2_5_VLVisionBlock(torch.nn.Module):
+     def __init__(self, model: torch.nn.Module, is_full_attn: bool, window_seq_len: int):
+         super().__init__()
+         self._origin_model = model
+         self.norm1 = model.norm1
+         self.norm2 = model.norm2
+
+         if is_full_attn:
+             self.attn = Qwen2_5_VLVisionFullAttention(model.attn)
+         else:
+             self.attn = Qwen2_5_VLVisionWindowAttention(model.attn, window_seq_len)
+         self.mlp = model.mlp
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attn_masks: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+     ) -> torch.Tensor:
+         hidden_states = hidden_states + self.attn(
+             self.norm1(hidden_states),
+             attn_masks,
+             position_embeddings,
+         )
+         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+         return hidden_states
+
+
+ class Qwen2_5_VLVisionFullAttention(nn.Module):
+     def __init__(self, model: nn.Module) -> None:
+         super().__init__()
+         self._origin_model = model
+         self.num_heads = model.num_heads
+         self.head_dim = model.head_dim
+         self.qkv = model.qkv
+         self.proj = model.proj
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attn_masks: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+     ) -> torch.Tensor:
+         seq_length = hidden_states.shape[0]
+         hidden_states = hidden_states.unsqueeze(0)
+         q, k, v = (
+             self.qkv(hidden_states).reshape(1, seq_length, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
+         )
+
+         cos, sin = position_embeddings
+         q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+         attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+         attn_weights = attn_weights + attn_masks
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+         attn_output = torch.matmul(attn_weights, v)
+         attn_output = attn_output.transpose(1, 2)
+         attn_output = attn_output.reshape(1, seq_length, -1)
+         attn_output = self.proj(attn_output).squeeze(0)
+
+         return attn_output
+
+
+ class Qwen2_5_VLVisionWindowAttention(nn.Module):
+     def __init__(self, model: nn.Module, window_seq_len: int) -> None:
+         super().__init__()
+         self._origin_model = model
+         self.num_heads = model.num_heads
+         self.head_dim = model.head_dim
+         self.qkv = model.qkv
+         self.proj = model.proj
+         self.window_seq_len = window_seq_len
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attn_masks: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+     ) -> torch.Tensor:
+         seq_length = hidden_states.shape[0]
+         num_windows = seq_length // self.window_seq_len
+
+         window_hidden_states = []
+         for i in range(0, seq_length, self.window_seq_len):
+             window_hidden_states.append(hidden_states[i : i + self.window_seq_len])
+         hidden_states = torch.stack(window_hidden_states)
+
+         q, k, v = (
+             self.qkv(hidden_states)
+             .reshape(num_windows, self.window_seq_len, 3, self.num_heads, -1)
+             .permute(2, 0, 3, 1, 4)
+             .unbind(0)
+         )
+         cos, sin = position_embeddings
+         cos = cos.reshape(num_windows, 1, seq_length // num_windows, -1)
+         sin = sin.reshape(num_windows, 1, seq_length // num_windows, -1)
+         q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+         attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+         attn_weights = attn_weights + attn_masks
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+         attn_output = torch.matmul(attn_weights, v)
+         attn_output = attn_output.transpose(1, 2)
+         attn_output = attn_output.reshape(1, seq_length, -1)
+         attn_output = self.proj(attn_output).squeeze(0)
+
+         return attn_output
+
+
+ class Qwen2_5_VL_LanguageModelWrapper(DecoderOnlyWrapper):
+     def forward(self, *args):
+         if self.phase == "decode":
+             if self.use_attention_mask:
+                 (
+                     input_ids_or_inputs_embeds,
+                     cache_position,
+                     attention_mask,
+                     block_tables,
+                     position_emb,
+                     *past_key_values,
+                 ) = args
+             else:
+                 (
+                     input_ids_or_inputs_embeds,
+                     cache_position,
+                     block_tables,
+                     position_emb,
+                     *past_key_values,
+                 ) = args
+                 attention_mask = None
+             query_position = None
+         elif self.phase == "prefill":
+             if self.use_attention_mask:
+                 (
+                     input_ids_or_inputs_embeds,
+                     cache_position,
+                     attention_mask,
+                     query_position,
+                     block_tables,
+                     position_emb,
+                     *past_key_values,
+                 ) = args
+             else:
+                 (
+                     input_ids_or_inputs_embeds,
+                     cache_position,
+                     query_position,
+                     block_tables,
+                     position_emb,
+                     *past_key_values,
+                 ) = args
+                 attention_mask = None
+
+         else:
+             raise ValueError(f"Unknown phase: {self.phase}")
+
+         return self.forward_common(
+             input_ids_or_inputs_embeds,
+             cache_position,
+             attention_mask,
+             query_position,
+             block_tables,
+             position_emb,
+             *past_key_values,
+         )
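A minimal standalone sketch (not part of the package; dummy sizes, plain PyTorch only) of what the window-attention path above does: the flat patch sequence is regrouped into fixed-size windows, and 0/1 masks are converted into additive masks before the per-window softmax. The tensor sizes and the mask shape chosen for broadcasting are assumptions for illustration.

import math

import torch

seq_len, hidden, window_seq_len, num_heads = 64, 32, 16, 4  # hypothetical sizes
hidden_states = torch.randn(seq_len, hidden)
num_windows = seq_len // window_seq_len

# Regroup the flat sequence into (num_windows, window_seq_len, hidden),
# mirroring the stacking loop in Qwen2_5_VLVisionWindowAttention.forward.
windows = torch.stack(
    [hidden_states[i : i + window_seq_len] for i in range(0, seq_len, window_seq_len)]
)
assert windows.shape == (num_windows, window_seq_len, hidden)

# 0/1 masks become additive masks: 0 keeps a position, a large negative value
# suppresses it after softmax, as in Qwen2_5_VisionTransformerWrapper.forward.
window_attn_masks = torch.ones(num_windows, 1, window_seq_len, window_seq_len)
additive_mask = (1 - window_attn_masks) * torch.finfo(torch.float32).min
scores = torch.randn(num_windows, num_heads, window_seq_len, window_seq_len) / math.sqrt(hidden // num_heads)
probs = torch.softmax(scores + additive_mask, dim=-1)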
optimum/rbln/transformers/models/seq2seq/__init__.py
@@ -12,4 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from .configuration_seq2seq2 import RBLNModelForSeq2SeqLMConfig
  from .modeling_seq2seq import RBLNModelForSeq2SeqLM
optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py
@@ -0,0 +1,66 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import rebel
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....utils.logging import get_logger
+
+
+ logger = get_logger()
+
+
+ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         enc_max_seq_len: Optional[int] = None,
+         dec_max_seq_len: Optional[int] = None,
+         use_attention_mask: Optional[bool] = None,
+         pad_token_id: Optional[int] = None,
+         **kwargs,
+     ):
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+             enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
+             dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
+             use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
+                 This is automatically set to True for RBLN-CA02 devices.
+             pad_token_id (Optional[int]): The ID of the padding token in the vocabulary.
+             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If batch_size is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         self.enc_max_seq_len = enc_max_seq_len
+         self.dec_max_seq_len = dec_max_seq_len
+
+         self.use_attention_mask = use_attention_mask
+         npu = self.npu or rebel.get_npu_name()
+         if npu == "RBLN-CA02":
+             if self.use_attention_mask is False:
+                 logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
+             self.use_attention_mask = True
+         else:
+             self.use_attention_mask = self.use_attention_mask or False
+
+         self.pad_token_id = pad_token_id
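A minimal usage sketch of the new config class (not taken from the diff; it assumes an environment where the rebel SDK is importable, since the constructor queries the NPU name, and the sequence lengths are hypothetical values):

from optimum.rbln.transformers.models.seq2seq import RBLNModelForSeq2SeqLMConfig

# batch_size defaults to 1; fields left unset, such as enc/dec_max_seq_len, are
# filled in later by _update_rbln_config from the model config or tokenizer (see below).
cfg = RBLNModelForSeq2SeqLMConfig(enc_max_seq_len=512, dec_max_seq_len=256)
print(cfg.batch_size, cfg.enc_max_seq_len, cfg.dec_max_seq_len, cfg.use_attention_mask)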
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -22,10 +22,11 @@ from rebel.compile_context import CompileContext
  from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

+ from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
- from ....modeling_config import RBLNCompileConfig, RBLNConfig
  from ....utils.logging import get_logger
  from ....utils.runtime_utils import RBLNPytorchRuntime
+ from .configuration_seq2seq2 import RBLNModelForSeq2SeqLMConfig


  logger = get_logger(__name__)
@@ -118,9 +119,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
      support_causal_attn = None

      def __post_init__(self, **kwargs):
-         batch_size = self.rbln_config.model_cfg["batch_size"]
-         dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
-         self.use_attention_mask = self.rbln_config.model_cfg.get("use_attention_mask", None)
+         batch_size = self.rbln_config.batch_size
+         dec_max_seq_len = self.rbln_config.dec_max_seq_len
+         self.use_attention_mask = self.rbln_config.use_attention_mask

          self.encoder = RBLNRuntimeEncoder(
              runtime=self.model[0],
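The same substitution repeats through the rest of this file: dict lookups on rbln_config.model_cfg become typed attributes on RBLNModelForSeq2SeqLMConfig. A toy illustration of the shape of that change (standalone hypothetical classes, not the package's own code):

from dataclasses import dataclass

# 0.7.4a4 style: values buried in an untyped dict on the config object
old_model_cfg = {"batch_size": 1, "dec_max_seq_len": 256}
batch_size = old_model_cfg["batch_size"]

# 0.7.4a6 style: typed attributes on a dedicated config class (sketched as a dataclass)
@dataclass
class Seq2SeqConfigSketch:
    batch_size: int = 1
    dec_max_seq_len: int = 256

batch_size = Seq2SeqConfigSketch().batch_size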
@@ -136,7 +137,7 @@
      @classmethod
      @torch.inference_mode()
-     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNConfig):
+     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
          wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

          enc_compile_config = rbln_config.compile_cfgs[0]
@@ -177,26 +178,15 @@
          return {"encoder": compiled_encoder, "decoder": compiled_decoder}

      @classmethod
-     def _get_rbln_config(
+     def _update_rbln_config(
          cls,
          preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-         model_config: "PretrainedConfig",
-         rbln_kwargs: Dict[str, Any] = {},
-     ) -> RBLNConfig:
-         rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
-         rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
-         rbln_batch_size = rbln_kwargs.get("batch_size", None)
-         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-         if cls.support_causal_attn:
-             rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
-             if rbln_use_attention_mask is None:
-                 rbln_use_attention_mask = False
-                 rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
-                 if rbln_npu == "RBLN-CA02":
-                     rbln_use_attention_mask = True
-         else:
-             rbln_use_attention_mask = True
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNModelForSeq2SeqLMConfig] = None,
+     ) -> RBLNModelForSeq2SeqLMConfig:
+         if not cls.support_causal_attn:
+             rbln_config.use_attention_mask = True

          n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
          n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
@@ -210,43 +200,44 @@
              model_config, "max_position_embeddings", None
          )

-         rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
-         if rbln_pad_token_id is None:
-             rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
-             if rbln_pad_token_id is None:
-                 rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
-                 if rbln_pad_token_id is None:
-                     rbln_pad_token_id = -1
-
-         if rbln_enc_max_seq_len is None:
-             rbln_enc_max_seq_len = max_position_embeddings
-             if rbln_enc_max_seq_len is None:
-                 for tokenizer in preprocessors:
-                     if hasattr(tokenizer, "model_max_length"):
-                         rbln_enc_max_seq_len = tokenizer.model_max_length
-                         break
-                 if rbln_enc_max_seq_len is None:
-                     raise ValueError("`rbln_enc_max_seq_len` should be specified!")
-         if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
-             raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
-
-         if rbln_dec_max_seq_len is None:
-             rbln_dec_max_seq_len = max_position_embeddings
-             if rbln_dec_max_seq_len is None:
-                 for tokenizer in preprocessors:
-                     if hasattr(tokenizer, "model_max_length"):
-                         rbln_dec_max_seq_len = tokenizer.model_max_length
-                         break
-                 if rbln_dec_max_seq_len is None:
-                     raise ValueError("`rbln_dec_max_seq_len` should be specified!")
-
-         if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
-             raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
+         pad_token_id = getattr(model_config, "pad_token_id", None)
+         pad_token_id = pad_token_id or getattr(model_config, "bos_token_id", None)
+         pad_token_id = pad_token_id or getattr(model_config, "eos_token_id", None)
+         pad_token_id = pad_token_id or -1
+         rbln_config.pad_token_id = pad_token_id
+
+         if rbln_config.enc_max_seq_len is None:
+             enc_max_seq_len = max_position_embeddings
+             for tokenizer in preprocessors:
+                 if hasattr(tokenizer, "model_max_length"):
+                     enc_max_seq_len = enc_max_seq_len or tokenizer.model_max_length
+                     break
+
+             if enc_max_seq_len is None:
+                 raise ValueError("`enc_max_seq_len` should be specified!")
+             rbln_config.enc_max_seq_len = enc_max_seq_len
+
+         if max_position_embeddings is not None and rbln_config.enc_max_seq_len > max_position_embeddings:
+             raise ValueError("`enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+         if rbln_config.dec_max_seq_len is None:
+             dec_max_seq_len = max_position_embeddings
+             for tokenizer in preprocessors:
+                 if hasattr(tokenizer, "model_max_length"):
+                     dec_max_seq_len = dec_max_seq_len or tokenizer.model_max_length
+                     break
+
+             if dec_max_seq_len is None:
+                 raise ValueError("`dec_max_seq_len` should be specified!")
+             rbln_config.dec_max_seq_len = dec_max_seq_len
+
+         if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
+             raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")

          # model input info
          enc_input_info = [
-             ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
-             ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
+             ("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
+             ("attention_mask", [1, rbln_config.enc_max_seq_len], "float32"),
              ("block_tables", [1], "int16"),
          ]
          enc_input_info.extend(
@@ -254,9 +245,9 @@
                  (
                      f"cross_key_value_states_{i}",
                      [
-                         rbln_batch_size,
+                         rbln_config.batch_size,
                          n_head,
-                         rbln_enc_max_seq_len,
+                         rbln_config.enc_max_seq_len,
                          d_kv,
                      ],
                      "float32",
@@ -266,23 +257,23 @@
          )

          dec_input_info = [
-             ("input_ids", [rbln_batch_size, 1], "int64"),
-             ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
+             ("input_ids", [rbln_config.batch_size, 1], "int64"),
+             ("encoder_attention_mask", [rbln_config.batch_size, rbln_config.enc_max_seq_len], "float32"),
              (
                  "cache_position",
-                 [rbln_batch_size, 1],
+                 [rbln_config.batch_size, 1],
                  "int32",
              ),
-             ("block_tables", [rbln_batch_size, 1], "int16"),
+             ("block_tables", [rbln_config.batch_size, 1], "int16"),
          ]
          dec_input_info.extend(
              [
                  (
                      f"cross_key_value_states_{i}",
                      [
-                         rbln_batch_size,
+                         rbln_config.batch_size,
                          n_head,
-                         rbln_enc_max_seq_len,
+                         rbln_config.enc_max_seq_len,
                          d_kv,
                      ],
                      "float32",
@@ -295,9 +286,9 @@
                  (
                      f"self_key_value_states_{i}",
                      [
-                         rbln_batch_size,
+                         rbln_config.batch_size,
                          n_head,
-                         rbln_dec_max_seq_len,
+                         rbln_config.dec_max_seq_len,
                          d_kv,
                      ],
                      "float32",
@@ -306,46 +297,38 @@
              ]
          )

-         if rbln_use_attention_mask:
-             dec_input_info.insert(1, ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
+         if rbln_config.use_attention_mask:
+             dec_input_info.insert(
+                 1, ("attention_mask", [rbln_config.batch_size, rbln_config.dec_max_seq_len], "float32")
+             )

          enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
          dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

-         rbln_config = RBLNConfig(
-             rbln_cls=cls.__name__,
-             compile_cfgs=[enc_compile_config, dec_compile_config],
-             rbln_kwargs=rbln_kwargs,
-         )
-
-         rbln_config.model_cfg.update(
-             {
-                 "enc_max_seq_len": rbln_enc_max_seq_len,
-                 "dec_max_seq_len": rbln_dec_max_seq_len,
-                 "batch_size": rbln_batch_size,
-                 "pad_token_id": rbln_pad_token_id,
-                 "use_attention_mask": rbln_use_attention_mask,
-             }
-         )
-
+         rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
          return rbln_config

      @classmethod
      def _create_runtimes(
          cls,
          compiled_models: List[rebel.RBLNCompiledModel],
-         rbln_device_map: Dict[str, int],
-         activate_profiler: Optional[bool] = None,
+         rbln_config: RBLNModelForSeq2SeqLMConfig,
      ) -> List[rebel.Runtime]:
-         if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
+         if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
              cls._raise_missing_compiled_file_error(["encoder", "decoder"])

          return [
-             compiled_models[0].create_runtime(
-                 tensor_type="pt", device=rbln_device_map["encoder"], activate_profiler=activate_profiler
+             rebel.Runtime(
+                 compiled_models[0],
+                 tensor_type="pt",
+                 device=rbln_config.device_map["encoder"],
+                 activate_profiler=rbln_config.activate_profiler,
              ),
-             compiled_models[1].create_runtime(
-                 tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+             rebel.Runtime(
+                 compiled_models[1],
+                 tensor_type="pt",
+                 device=rbln_config.device_map["decoder"],
+                 activate_profiler=rbln_config.activate_profiler,
              ),
          ]

@@ -367,7 +350,7 @@
      ):
          cur_seq_len = input_ids.shape[-1]
          cache_position = cur_seq_len - 1
-         max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+         max_seq_len = self.rbln_config.dec_max_seq_len
          decoder_batch_size = input_ids.shape[0]
          input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
          decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
@@ -387,7 +370,7 @@
          **kwargs,
      ) -> Tuple[torch.FloatTensor]:
          # common decoder
-         cache_position = torch.full((self.rbln_config.model_cfg["batch_size"], 1), cache_position, dtype=torch.int32)
+         cache_position = torch.full((self.rbln_config.batch_size, 1), cache_position, dtype=torch.int32)
          logits = self.decoder(decoder_input_ids=decoder_input_ids, cache_position=cache_position, **kwargs).logits

          return Seq2SeqLMOutput(
@@ -421,11 +404,11 @@
          batch_size, input_len = inputs_tensor.shape
          inputs_tensor = torch.nn.functional.pad(
              inputs_tensor,
-             (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len),
-             value=self.rbln_config.model_cfg["pad_token_id"],
+             (0, self.rbln_config.enc_max_seq_len - input_len),
+             value=self.rbln_config.pad_token_id,
          )
          model_kwargs["attention_mask"] = torch.nn.functional.pad(
-             model_kwargs["attention_mask"], (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len)
+             model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
          )

          # 3. make sure that encoder returns `ModelOutput`
optimum/rbln/transformers/models/t5/__init__.py
@@ -13,4 +13,5 @@
  # limitations under the License.

  from ....ops import paged_add_softmax_attn_decode
+ from .configuration_t5 import RBLNT5EncoderModelConfig, RBLNT5ForConditionalGenerationConfig
  from .modeling_t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
optimum/rbln/transformers/models/t5/configuration_t5.py
@@ -0,0 +1,24 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
+ from ..seq2seq import RBLNModelForSeq2SeqLMConfig
+
+
+ class RBLNT5EncoderModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+     pass
+
+
+ class RBLNT5ForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
+     pass
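Since both new T5 configs are pass-through subclasses, they accept the same arguments as their parents; a hypothetical instantiation (again assuming an environment where the rebel SDK is importable, with made-up sizes):

from optimum.rbln.transformers.models.seq2seq import RBLNModelForSeq2SeqLMConfig
from optimum.rbln.transformers.models.t5 import RBLNT5ForConditionalGenerationConfig

cfg = RBLNT5ForConditionalGenerationConfig(batch_size=2, enc_max_seq_len=512, dec_max_seq_len=512)
assert isinstance(cfg, RBLNModelForSeq2SeqLMConfig)  # inherits the seq2seq defaults and validation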