optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,214 @@
+ import math
+ from typing import Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from ..decoderonly.decoderonly_architecture import (
+     DecoderOnlyWrapper,
+     apply_rotary_pos_emb,
+ )
+
+
+ class Qwen2_5_VisionTransformerWrapper(nn.Module):
+     def __init__(self, model: torch.nn.Module):
+         super().__init__()
+         self._original_mod = model
+         self.fullatt_block_indexes = model.fullatt_block_indexes
+         self.merger = model.merger
+         window_seq_len = (model.window_size // model.patch_size) ** 2
+         self.blocks = self.wrap_vision_blocks(model.blocks, window_seq_len)
+
+     def wrap_vision_blocks(self, blocks: torch.nn.ModuleList, window_seq_len: int):
+         wrapped_blocks = []
+         for i, block in enumerate(blocks):
+             is_full_attn = True if i in self.fullatt_block_indexes else False
+             wrapped_blocks.append(Qwen2_5_VLVisionBlock(block, is_full_attn, window_seq_len))
+         return nn.ModuleList(wrapped_blocks)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         full_attn_masks: torch.Tensor,
+         window_attn_masks: torch.Tensor,
+         cos: torch.Tensor,
+         sin: torch.Tensor,
+     ):
+         full_attn_masks = (1 - full_attn_masks) * torch.finfo(torch.float32).min
+         window_attn_masks = (1 - window_attn_masks) * torch.finfo(torch.float32).min
+
+         for i, block in enumerate(self.blocks):
+             attn_masks = full_attn_masks if i in self.fullatt_block_indexes else window_attn_masks
+             hidden_states = block(hidden_states, attn_masks, [cos, sin])
+
+         hidden_states = self.merger(hidden_states)
+
+         return hidden_states
+
+
+ class Qwen2_5_VLVisionBlock(torch.nn.Module):
+     def __init__(self, model: torch.nn.Module, is_full_attn: bool, window_seq_len: int):
+         super().__init__()
+         self._origin_model = model
+         self.norm1 = model.norm1
+         self.norm2 = model.norm2
+
+         if is_full_attn:
+             self.attn = Qwen2_5_VLVisionFullAttention(model.attn)
+         else:
+             self.attn = Qwen2_5_VLVisionWindowAttention(model.attn, window_seq_len)
+         self.mlp = model.mlp
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attn_masks: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+     ) -> torch.Tensor:
+         hidden_states = hidden_states + self.attn(
+             self.norm1(hidden_states),
+             attn_masks,
+             position_embeddings,
+         )
+         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+         return hidden_states
+
+
+ class Qwen2_5_VLVisionFullAttention(nn.Module):
+     def __init__(self, model: nn.Module) -> None:
+         super().__init__()
+         self._origin_model = model
+         self.num_heads = model.num_heads
+         self.head_dim = model.head_dim
+         self.qkv = model.qkv
+         self.proj = model.proj
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attn_masks: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+     ) -> torch.Tensor:
+         seq_length = hidden_states.shape[0]
+         hidden_states = hidden_states.unsqueeze(0)
+         q, k, v = (
+             self.qkv(hidden_states).reshape(1, seq_length, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
+         )
+
+         cos, sin = position_embeddings
+         q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+         attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+         attn_weights = attn_weights + attn_masks
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+         attn_output = torch.matmul(attn_weights, v)
+         attn_output = attn_output.transpose(1, 2)
+         attn_output = attn_output.reshape(1, seq_length, -1)
+         attn_output = self.proj(attn_output).squeeze(0)
+
+         return attn_output
+
+
+ class Qwen2_5_VLVisionWindowAttention(nn.Module):
+     def __init__(self, model: nn.Module, window_seq_len: int) -> None:
+         super().__init__()
+         self._origin_model = model
+         self.num_heads = model.num_heads
+         self.head_dim = model.head_dim
+         self.qkv = model.qkv
+         self.proj = model.proj
+         self.window_seq_len = window_seq_len
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attn_masks: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+     ) -> torch.Tensor:
+         seq_length = hidden_states.shape[0]
+         num_windows = seq_length // self.window_seq_len
+
+         window_hidden_states = []
+         for i in range(0, seq_length, self.window_seq_len):
+             window_hidden_states.append(hidden_states[i : i + self.window_seq_len])
+         hidden_states = torch.stack(window_hidden_states)
+
+         q, k, v = (
+             self.qkv(hidden_states)
+             .reshape(num_windows, self.window_seq_len, 3, self.num_heads, -1)
+             .permute(2, 0, 3, 1, 4)
+             .unbind(0)
+         )
+         cos, sin = position_embeddings
+         cos = cos.reshape(num_windows, 1, seq_length // num_windows, -1)
+         sin = sin.reshape(num_windows, 1, seq_length // num_windows, -1)
+         q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+         attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+         attn_weights = attn_weights + attn_masks
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+         attn_output = torch.matmul(attn_weights, v)
+         attn_output = attn_output.transpose(1, 2)
+         attn_output = attn_output.reshape(1, seq_length, -1)
+         attn_output = self.proj(attn_output).squeeze(0)
+
+         return attn_output
+
+
+ class Qwen2_5_VL_LanguageModelWrapper(DecoderOnlyWrapper):
+     def forward(self, *args):
+         if self.phase == "decode":
+             if self.use_attention_mask:
+                 (
+                     input_ids_or_inputs_embeds,
+                     cache_position,
+                     attention_mask,
+                     block_tables,
+                     position_emb,
+                     *past_key_values,
+                 ) = args
+             else:
+                 (
+                     input_ids_or_inputs_embeds,
+                     cache_position,
+                     block_tables,
+                     position_emb,
+                     *past_key_values,
+                 ) = args
+                 attention_mask = None
+             query_position = None
+         elif self.phase == "prefill":
+             if self.use_attention_mask:
+                 (
+                     input_ids_or_inputs_embeds,
+                     cache_position,
+                     attention_mask,
+                     query_position,
+                     block_tables,
+                     position_emb,
+                     *past_key_values,
+                 ) = args
+             else:
+                 (
+                     input_ids_or_inputs_embeds,
+                     cache_position,
+                     query_position,
+                     block_tables,
+                     position_emb,
+                     *past_key_values,
+                 ) = args
+                 attention_mask = None
+
+         else:
+             raise ValueError(f"Unknown phase: {self.phase}")
+
+         return self.forward_common(
+             input_ids_or_inputs_embeds,
+             cache_position,
+             attention_mask,
+             query_position,
+             block_tables,
+             position_emb,
+             *past_key_values,
+         )
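
The vision wrapper above expects pre-built 0/1 attention masks and converts them to additive biases itself via `(1 - mask) * torch.finfo(torch.float32).min`. A minimal, self-contained sketch of that masking convention (an illustration with made-up sizes, not code from the package):

    import math
    import torch
    import torch.nn as nn

    # Masks arrive as 0/1 tensors (1 = "may attend"); zeros become large
    # negative biases before softmax, mirroring the wrapper's forward().
    seq_len, num_heads, head_dim = 4, 2, 8
    scores = torch.randn(1, num_heads, seq_len, seq_len) / math.sqrt(head_dim)

    keep = torch.tensor([1.0, 1.0, 1.0, 0.0])                  # last position is padding
    mask = keep.expand(seq_len, seq_len).reshape(1, 1, seq_len, seq_len)

    additive = (1 - mask) * torch.finfo(torch.float32).min     # 1 -> 0, 0 -> ~-inf
    probs = nn.functional.softmax(scores + additive, dim=-1, dtype=torch.float32)

    print(probs[..., -1])  # ~0 everywhere: the masked column receives no attention weight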
@@ -12,4 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from .configuration_seq2seq2 import RBLNModelForSeq2SeqLMConfig
  from .modeling_seq2seq import RBLNModelForSeq2SeqLM
@@ -0,0 +1,66 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import rebel
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....utils.logging import get_logger
+
+
+ logger = get_logger()
+
+
+ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         enc_max_seq_len: Optional[int] = None,
+         dec_max_seq_len: Optional[int] = None,
+         use_attention_mask: Optional[bool] = None,
+         pad_token_id: Optional[int] = None,
+         **kwargs,
+     ):
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+             enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
+             dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
+             use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
+                 This is automatically set to True for RBLN-CA02 devices.
+             pad_token_id (Optional[int]): The ID of the padding token in the vocabulary.
+             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If batch_size is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         self.enc_max_seq_len = enc_max_seq_len
+         self.dec_max_seq_len = dec_max_seq_len
+
+         self.use_attention_mask = use_attention_mask
+         npu = self.npu or rebel.get_npu_name()
+         if npu == "RBLN-CA02":
+             if self.use_attention_mask is False:
+                 logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
+             self.use_attention_mask = True
+         else:
+             self.use_attention_mask = self.use_attention_mask or False
+
+         self.pad_token_id = pad_token_id
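
A rough usage sketch of the new config class, assuming a host with the RBLN SDK installed (the values are illustrative; on RBLN-CA02 devices `use_attention_mask` is forced to True as shown above):

    from optimum.rbln.transformers.models.seq2seq import RBLNModelForSeq2SeqLMConfig

    # Illustrative values; enc/dec max lengths left unset are otherwise inferred
    # from the model config or tokenizer later, during _update_rbln_config.
    cfg = RBLNModelForSeq2SeqLMConfig(
        batch_size=2,
        enc_max_seq_len=512,
        dec_max_seq_len=256,
    )
    print(cfg.batch_size, cfg.enc_max_seq_len, cfg.dec_max_seq_len, cfg.use_attention_mask)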
@@ -22,10 +22,11 @@ from rebel.compile_context import CompileContext
22
22
  from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
23
23
  from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
24
24
 
25
+ from ....configuration_utils import RBLNCompileConfig
25
26
  from ....modeling import RBLNModel
26
- from ....modeling_config import RBLNCompileConfig, RBLNConfig
27
27
  from ....utils.logging import get_logger
28
28
  from ....utils.runtime_utils import RBLNPytorchRuntime
29
+ from .configuration_seq2seq2 import RBLNModelForSeq2SeqLMConfig
29
30
 
30
31
 
31
32
  logger = get_logger(__name__)
@@ -38,8 +39,8 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
38
39
  mandatory_members = ["main_input_name"]
39
40
 
40
41
  def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
41
- _ = super().forward(*args, **kwargs)
42
- return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
42
+ output = super().forward(*args, **kwargs)
43
+ return BaseModelOutput(last_hidden_state=output)
43
44
 
44
45
 
45
46
  class RBLNRuntimeDecoder(RBLNPytorchRuntime):
@@ -94,7 +95,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
94
95
  decoder_attention_mask if self.use_attention_mask else None,
95
96
  attention_mask,
96
97
  cache_position,
97
- block_tables,
98
+ block_tables=block_tables,
98
99
  )
99
100
 
100
101
  return Seq2SeqLMOutput(logits=lm_logits)
@@ -115,11 +116,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
115
116
 
116
117
  main_input_name = "input_ids"
117
118
  auto_model_class = AutoModelForSeq2SeqLM
119
+ support_causal_attn = None
118
120
 
119
121
  def __post_init__(self, **kwargs):
120
- batch_size = self.rbln_config.model_cfg["batch_size"]
121
- dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
122
- self.use_attention_mask = self.rbln_config.model_cfg.get("use_attention_mask", None)
122
+ batch_size = self.rbln_config.batch_size
123
+ dec_max_seq_len = self.rbln_config.dec_max_seq_len
124
+ self.use_attention_mask = self.rbln_config.use_attention_mask
123
125
 
124
126
  self.encoder = RBLNRuntimeEncoder(
125
127
  runtime=self.model[0],
@@ -135,7 +137,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
135
137
 
136
138
  @classmethod
137
139
  @torch.inference_mode()
138
- def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNConfig):
140
+ def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
139
141
  wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
140
142
 
141
143
  enc_compile_config = rbln_config.compile_cfgs[0]
@@ -176,23 +178,15 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
176
178
  return {"encoder": compiled_encoder, "decoder": compiled_decoder}
177
179
 
178
180
  @classmethod
179
- def _get_rbln_config(
181
+ def _update_rbln_config(
180
182
  cls,
181
183
  preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
182
- model_config: "PretrainedConfig",
183
- rbln_kwargs: Dict[str, Any] = {},
184
- ) -> RBLNConfig:
185
- rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
186
- rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
187
- rbln_batch_size = rbln_kwargs.get("batch_size", None)
188
- rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
189
- rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
190
-
191
- if rbln_use_attention_mask is None:
192
- rbln_use_attention_mask = False
193
- rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
194
- if rbln_npu == "RBLN-CA02":
195
- rbln_use_attention_mask = True
184
+ model: Optional["PreTrainedModel"] = None,
185
+ model_config: Optional["PretrainedConfig"] = None,
186
+ rbln_config: Optional[RBLNModelForSeq2SeqLMConfig] = None,
187
+ ) -> RBLNModelForSeq2SeqLMConfig:
188
+ if not cls.support_causal_attn:
189
+ rbln_config.use_attention_mask = True
196
190
 
197
191
  n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
198
192
  n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
@@ -206,84 +200,85 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
206
200
  model_config, "max_position_embeddings", None
207
201
  )
208
202
 
209
- rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
210
- if rbln_pad_token_id is None:
211
- rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
212
- if rbln_pad_token_id is None:
213
- rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
214
- if rbln_pad_token_id is None:
215
- rbln_pad_token_id = -1
216
-
217
- if rbln_enc_max_seq_len is None:
218
- rbln_enc_max_seq_len = max_position_embeddings
219
- if rbln_enc_max_seq_len is None:
220
- for tokenizer in preprocessors:
221
- if hasattr(tokenizer, "model_max_length"):
222
- rbln_enc_max_seq_len = tokenizer.model_max_length
223
- break
224
- if rbln_enc_max_seq_len is None:
225
- raise ValueError("`rbln_enc_max_seq_len` should be specified!")
226
- if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
227
- raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
228
-
229
- if rbln_dec_max_seq_len is None:
230
- rbln_dec_max_seq_len = max_position_embeddings
231
- if rbln_dec_max_seq_len is None:
232
- for tokenizer in preprocessors:
233
- if hasattr(tokenizer, "model_max_length"):
234
- rbln_dec_max_seq_len = tokenizer.model_max_length
235
- break
236
- if rbln_dec_max_seq_len is None:
237
- raise ValueError("`rbln_dec_max_seq_len` should be specified!")
238
-
239
- if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
240
- raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
203
+ pad_token_id = getattr(model_config, "pad_token_id", None)
204
+ pad_token_id = pad_token_id or getattr(model_config, "bos_token_id", None)
205
+ pad_token_id = pad_token_id or getattr(model_config, "eos_token_id", None)
206
+ pad_token_id = pad_token_id or -1
207
+ rbln_config.pad_token_id = pad_token_id
208
+
209
+ if rbln_config.enc_max_seq_len is None:
210
+ enc_max_seq_len = max_position_embeddings
211
+ for tokenizer in preprocessors:
212
+ if hasattr(tokenizer, "model_max_length"):
213
+ enc_max_seq_len = enc_max_seq_len or tokenizer.model_max_length
214
+ break
215
+
216
+ if enc_max_seq_len is None:
217
+ raise ValueError("`enc_max_seq_len` should be specified!")
218
+ rbln_config.enc_max_seq_len = enc_max_seq_len
219
+
220
+ if max_position_embeddings is not None and rbln_config.enc_max_seq_len > max_position_embeddings:
221
+ raise ValueError("`enc_max_seq_len` should be less or equal than max_position_embeddings!")
222
+
223
+ if rbln_config.dec_max_seq_len is None:
224
+ dec_max_seq_len = max_position_embeddings
225
+ for tokenizer in preprocessors:
226
+ if hasattr(tokenizer, "model_max_length"):
227
+ dec_max_seq_len = dec_max_seq_len or tokenizer.model_max_length
228
+ break
229
+
230
+ if dec_max_seq_len is None:
231
+ raise ValueError("`dec_max_seq_len` should be specified!")
232
+ rbln_config.dec_max_seq_len = dec_max_seq_len
233
+
234
+ if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
235
+ raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")
241
236
 
242
237
  # model input info
243
238
  enc_input_info = [
244
- ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
245
- ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
246
- (
247
- "cross_key_value_states",
248
- [
249
- n_layer * 2,
250
- rbln_batch_size,
251
- n_head,
252
- rbln_enc_max_seq_len,
253
- d_kv,
254
- ],
255
- "float32",
256
- ),
239
+ ("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
240
+ ("attention_mask", [1, rbln_config.enc_max_seq_len], "float32"),
257
241
  ("block_tables", [1], "int16"),
258
242
  ]
243
+ enc_input_info.extend(
244
+ [
245
+ (
246
+ f"cross_key_value_states_{i}",
247
+ [
248
+ rbln_config.batch_size,
249
+ n_head,
250
+ rbln_config.enc_max_seq_len,
251
+ d_kv,
252
+ ],
253
+ "float32",
254
+ )
255
+ for i in range(n_layer * 2)
256
+ ]
257
+ )
259
258
 
260
259
  dec_input_info = [
261
- ("input_ids", [rbln_batch_size, 1], "int64"),
262
- ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
260
+ ("input_ids", [rbln_config.batch_size, 1], "int64"),
261
+ ("encoder_attention_mask", [rbln_config.batch_size, rbln_config.enc_max_seq_len], "float32"),
263
262
  (
264
263
  "cache_position",
265
- [rbln_batch_size, 1],
264
+ [rbln_config.batch_size, 1],
266
265
  "int32",
267
266
  ),
268
- (
269
- "block_tables",
270
- [rbln_batch_size, 1],
271
- "int16",
272
- ),
267
+ ("block_tables", [rbln_config.batch_size, 1], "int16"),
273
268
  ]
274
269
  dec_input_info.extend(
275
270
  [
276
271
  (
277
- "cross_key_value_states",
272
+ f"cross_key_value_states_{i}",
278
273
  [
279
- n_layer * 2,
280
- rbln_batch_size,
274
+ rbln_config.batch_size,
281
275
  n_head,
282
- rbln_enc_max_seq_len,
276
+ rbln_config.enc_max_seq_len,
283
277
  d_kv,
284
278
  ],
285
279
  "float32",
286
280
  )
281
+ for i in range(n_layer * 2)
287
282
  ]
288
283
  )
289
284
  dec_input_info.extend(
@@ -291,9 +286,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
291
286
  (
292
287
  f"self_key_value_states_{i}",
293
288
  [
294
- rbln_batch_size,
289
+ rbln_config.batch_size,
295
290
  n_head,
296
- rbln_dec_max_seq_len,
291
+ rbln_config.dec_max_seq_len,
297
292
  d_kv,
298
293
  ],
299
294
  "float32",
@@ -302,46 +297,38 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
302
297
  ]
303
298
  )
304
299
 
305
- if rbln_use_attention_mask:
306
- dec_input_info.insert(1, ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
300
+ if rbln_config.use_attention_mask:
301
+ dec_input_info.insert(
302
+ 1, ("attention_mask", [rbln_config.batch_size, rbln_config.dec_max_seq_len], "float32")
303
+ )
307
304
 
308
305
  enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
309
306
  dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
310
307
 
311
- rbln_config = RBLNConfig(
312
- rbln_cls=cls.__name__,
313
- compile_cfgs=[enc_compile_config, dec_compile_config],
314
- rbln_kwargs=rbln_kwargs,
315
- )
316
-
317
- rbln_config.model_cfg.update(
318
- {
319
- "enc_max_seq_len": rbln_enc_max_seq_len,
320
- "dec_max_seq_len": rbln_dec_max_seq_len,
321
- "batch_size": rbln_batch_size,
322
- "pad_token_id": rbln_pad_token_id,
323
- "use_attention_mask": rbln_use_attention_mask,
324
- }
325
- )
326
-
308
+ rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
327
309
  return rbln_config
328
310
 
329
311
  @classmethod
330
312
  def _create_runtimes(
331
313
  cls,
332
314
  compiled_models: List[rebel.RBLNCompiledModel],
333
- rbln_device_map: Dict[str, int],
334
- activate_profiler: Optional[bool] = None,
315
+ rbln_config: RBLNModelForSeq2SeqLMConfig,
335
316
  ) -> List[rebel.Runtime]:
336
- if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
317
+ if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
337
318
  cls._raise_missing_compiled_file_error(["encoder", "decoder"])
338
319
 
339
320
  return [
340
- compiled_models[0].create_runtime(
341
- tensor_type="pt", device=rbln_device_map["encoder"], activate_profiler=activate_profiler
321
+ rebel.Runtime(
322
+ compiled_models[0],
323
+ tensor_type="pt",
324
+ device=rbln_config.device_map["encoder"],
325
+ activate_profiler=rbln_config.activate_profiler,
342
326
  ),
343
- compiled_models[1].create_runtime(
344
- tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
327
+ rebel.Runtime(
328
+ compiled_models[1],
329
+ tensor_type="pt",
330
+ device=rbln_config.device_map["decoder"],
331
+ activate_profiler=rbln_config.activate_profiler,
345
332
  ),
346
333
  ]
347
334
 
@@ -363,7 +350,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
363
350
  ):
364
351
  cur_seq_len = input_ids.shape[-1]
365
352
  cache_position = cur_seq_len - 1
366
- max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
353
+ max_seq_len = self.rbln_config.dec_max_seq_len
367
354
  decoder_batch_size = input_ids.shape[0]
368
355
  input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
369
356
  decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
@@ -383,7 +370,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
383
370
  **kwargs,
384
371
  ) -> Tuple[torch.FloatTensor]:
385
372
  # common decoder
386
- cache_position = torch.full((self.rbln_config.model_cfg["batch_size"], 1), cache_position, dtype=torch.int32)
373
+ cache_position = torch.full((self.rbln_config.batch_size, 1), cache_position, dtype=torch.int32)
387
374
  logits = self.decoder(decoder_input_ids=decoder_input_ids, cache_position=cache_position, **kwargs).logits
388
375
 
389
376
  return Seq2SeqLMOutput(
@@ -417,11 +404,11 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
417
404
  batch_size, input_len = inputs_tensor.shape
418
405
  inputs_tensor = torch.nn.functional.pad(
419
406
  inputs_tensor,
420
- (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len),
421
- value=self.rbln_config.model_cfg["pad_token_id"],
407
+ (0, self.rbln_config.enc_max_seq_len - input_len),
408
+ value=self.rbln_config.pad_token_id,
422
409
  )
423
410
  model_kwargs["attention_mask"] = torch.nn.functional.pad(
424
- model_kwargs["attention_mask"], (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len)
411
+ model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
425
412
  )
426
413
 
427
414
  # 3. make sure that encoder returns `ModelOutput`