optimum-rbln 0.7.3.post1__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +11 -86
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -118
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +23 -151
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post1.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
--- /dev/null
+++ optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py
@@ -0,0 +1,214 @@
+ import math
+ from typing import Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from ..decoderonly.decoderonly_architecture import (
+     DecoderOnlyWrapper,
+     apply_rotary_pos_emb,
+ )
+
+
+ class Qwen2_5_VisionTransformerWrapper(nn.Module):
+     def __init__(self, model: torch.nn.Module):
+         super().__init__()
+         self._original_mod = model
+         self.fullatt_block_indexes = model.fullatt_block_indexes
+         self.merger = model.merger
+         window_seq_len = (model.window_size // model.patch_size) ** 2
+         self.blocks = self.wrap_vision_blocks(model.blocks, window_seq_len)
+
+     def wrap_vision_blocks(self, blocks: torch.nn.ModuleList, window_seq_len: int):
+         wrapped_blocks = []
+         for i, block in enumerate(blocks):
+             is_full_attn = True if i in self.fullatt_block_indexes else False
+             wrapped_blocks.append(Qwen2_5_VLVisionBlock(block, is_full_attn, window_seq_len))
+         return nn.ModuleList(wrapped_blocks)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         full_attn_masks: torch.Tensor,
+         window_attn_masks: torch.Tensor,
+         cos: torch.Tensor,
+         sin: torch.Tensor,
+     ):
+         full_attn_masks = (1 - full_attn_masks) * torch.finfo(torch.float32).min
+         window_attn_masks = (1 - window_attn_masks) * torch.finfo(torch.float32).min
+
+         for i, block in enumerate(self.blocks):
+             attn_masks = full_attn_masks if i in self.fullatt_block_indexes else window_attn_masks
+             hidden_states = block(hidden_states, attn_masks, [cos, sin])
+
+         hidden_states = self.merger(hidden_states)
+
+         return hidden_states
+
+
+ class Qwen2_5_VLVisionBlock(torch.nn.Module):
+     def __init__(self, model: torch.nn.Module, is_full_attn: bool, window_seq_len: int):
+         super().__init__()
+         self._origin_model = model
+         self.norm1 = model.norm1
+         self.norm2 = model.norm2
+
+         if is_full_attn:
+             self.attn = Qwen2_5_VLVisionFullAttention(model.attn)
+         else:
+             self.attn = Qwen2_5_VLVisionWindowAttention(model.attn, window_seq_len)
+         self.mlp = model.mlp
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attn_masks: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+     ) -> torch.Tensor:
+         hidden_states = hidden_states + self.attn(
+             self.norm1(hidden_states),
+             attn_masks,
+             position_embeddings,
+         )
+         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+         return hidden_states
+
+
+ class Qwen2_5_VLVisionFullAttention(nn.Module):
+     def __init__(self, model: nn.Module) -> None:
+         super().__init__()
+         self._origin_model = model
+         self.num_heads = model.num_heads
+         self.head_dim = model.head_dim
+         self.qkv = model.qkv
+         self.proj = model.proj
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attn_masks: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+     ) -> torch.Tensor:
+         seq_length = hidden_states.shape[0]
+         hidden_states = hidden_states.unsqueeze(0)
+         q, k, v = (
+             self.qkv(hidden_states).reshape(1, seq_length, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
+         )
+
+         cos, sin = position_embeddings
+         q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+         attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+         attn_weights = attn_weights + attn_masks
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+         attn_output = torch.matmul(attn_weights, v)
+         attn_output = attn_output.transpose(1, 2)
+         attn_output = attn_output.reshape(1, seq_length, -1)
+         attn_output = self.proj(attn_output).squeeze(0)
+
+         return attn_output
+
+
+ class Qwen2_5_VLVisionWindowAttention(nn.Module):
+     def __init__(self, model: nn.Module, window_seq_len: int) -> None:
+         super().__init__()
+         self._origin_model = model
+         self.num_heads = model.num_heads
+         self.head_dim = model.head_dim
+         self.qkv = model.qkv
+         self.proj = model.proj
+         self.window_seq_len = window_seq_len
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attn_masks: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+     ) -> torch.Tensor:
+         seq_length = hidden_states.shape[0]
+         num_windows = seq_length // self.window_seq_len
+
+         window_hidden_states = []
+         for i in range(0, seq_length, self.window_seq_len):
+             window_hidden_states.append(hidden_states[i : i + self.window_seq_len])
+         hidden_states = torch.stack(window_hidden_states)
+
+         q, k, v = (
+             self.qkv(hidden_states)
+             .reshape(num_windows, self.window_seq_len, 3, self.num_heads, -1)
+             .permute(2, 0, 3, 1, 4)
+             .unbind(0)
+         )
+         cos, sin = position_embeddings
+         cos = cos.reshape(num_windows, 1, seq_length // num_windows, -1)
+         sin = sin.reshape(num_windows, 1, seq_length // num_windows, -1)
+         q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+         attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+         attn_weights = attn_weights + attn_masks
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+         attn_output = torch.matmul(attn_weights, v)
+         attn_output = attn_output.transpose(1, 2)
+         attn_output = attn_output.reshape(1, seq_length, -1)
+         attn_output = self.proj(attn_output).squeeze(0)
+
+         return attn_output
+
+
+ class Qwen2_5_VL_LanguageModelWrapper(DecoderOnlyWrapper):
+     def forward(self, *args):
+         if self.phase == "decode":
+             if self.use_attention_mask:
+                 (
+                     input_ids_or_inputs_embeds,
+                     cache_position,
+                     attention_mask,
+                     block_tables,
+                     position_emb,
+                     *past_key_values,
+                 ) = args
+             else:
+                 (
+                     input_ids_or_inputs_embeds,
+                     cache_position,
+                     block_tables,
+                     position_emb,
+                     *past_key_values,
+                 ) = args
+                 attention_mask = None
+             query_position = None
+         elif self.phase == "prefill":
+             if self.use_attention_mask:
+                 (
+                     input_ids_or_inputs_embeds,
+                     cache_position,
+                     attention_mask,
+                     query_position,
+                     block_tables,
+                     position_emb,
+                     *past_key_values,
+                 ) = args
+             else:
+                 (
+                     input_ids_or_inputs_embeds,
+                     cache_position,
+                     query_position,
+                     block_tables,
+                     position_emb,
+                     *past_key_values,
+                 ) = args
+                 attention_mask = None
+
+         else:
+             raise ValueError(f"Unknown phase: {self.phase}")
+
+         return self.forward_common(
+             input_ids_or_inputs_embeds,
+             cache_position,
+             attention_mask,
+             query_position,
+             block_tables,
+             position_emb,
+             *past_key_values,
+         )
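
Note: the wrapper above turns 1/0 attention masks into additive biases and regroups the flat patch sequence into fixed-size windows for the window-attention blocks. A minimal standalone sketch of those two steps, with toy shapes and only torch assumed (not part of the package):

import torch

seq_len, window_seq_len, hidden = 8, 4, 16
num_windows = seq_len // window_seq_len
hidden_states = torch.randn(seq_len, hidden)

# A 1 = attend / 0 = ignore mask becomes a large negative additive bias,
# mirroring `(1 - mask) * torch.finfo(torch.float32).min` in the wrapper.
window_attn_masks = torch.ones(num_windows, 1, window_seq_len, window_seq_len)
additive_mask = (1 - window_attn_masks) * torch.finfo(torch.float32).min

# Window attention stacks contiguous chunks of `window_seq_len` patches so each
# window attends only within itself, as in Qwen2_5_VLVisionWindowAttention.
windows = torch.stack(
    [hidden_states[i : i + window_seq_len] for i in range(0, seq_len, window_seq_len)]
)
print(windows.shape)        # torch.Size([2, 4, 16])
print(additive_mask.shape)  # torch.Size([2, 1, 4, 4])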
--- optimum/rbln/transformers/models/seq2seq/__init__.py
+++ optimum/rbln/transformers/models/seq2seq/__init__.py
@@ -12,4 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
+ from .configuration_seq2seq2 import RBLNModelForSeq2SeqLMConfig
  from .modeling_seq2seq import RBLNModelForSeq2SeqLM
--- /dev/null
+++ optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py
@@ -0,0 +1,66 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import rebel
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....utils.logging import get_logger
+
+
+ logger = get_logger()
+
+
+ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         enc_max_seq_len: Optional[int] = None,
+         dec_max_seq_len: Optional[int] = None,
+         use_attention_mask: Optional[bool] = None,
+         pad_token_id: Optional[int] = None,
+         **kwargs,
+     ):
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+             enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
+             dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
+             use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
+                 This is automatically set to True for RBLN-CA02 devices.
+             pad_token_id (Optional[int]): The ID of the padding token in the vocabulary.
+             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If batch_size is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         self.enc_max_seq_len = enc_max_seq_len
+         self.dec_max_seq_len = dec_max_seq_len
+
+         self.use_attention_mask = use_attention_mask
+         npu = self.npu or rebel.get_npu_name()
+         if npu == "RBLN-CA02":
+             if self.use_attention_mask is False:
+                 logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
+             self.use_attention_mask = True
+         else:
+             self.use_attention_mask = self.use_attention_mask or False
+
+         self.pad_token_id = pad_token_id
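
Note: a hypothetical usage sketch for the new configuration class, not taken from the package docs. It assumes the RBLN SDK (`rebel`) is installed, since the constructor may query the NPU name, and the argument values are illustrative only:

from optimum.rbln.transformers.models.seq2seq import RBLNModelForSeq2SeqLMConfig

# batch_size defaults to 1; enc/dec_max_seq_len can be left unset and are then
# inferred later in _update_rbln_config from the model config or tokenizer
# (see the modeling_seq2seq.py hunks below).
cfg = RBLNModelForSeq2SeqLMConfig(
    batch_size=1,
    enc_max_seq_len=512,
    dec_max_seq_len=256,
)
# On RBLN-CA02 devices use_attention_mask is forced to True, with a warning if
# it was explicitly set to False.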
--- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
+++ optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -22,10 +22,11 @@ from rebel.compile_context import CompileContext
  from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
+ from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
- from ....modeling_config import RBLNCompileConfig, RBLNConfig
  from ....utils.logging import get_logger
  from ....utils.runtime_utils import RBLNPytorchRuntime
+ from .configuration_seq2seq2 import RBLNModelForSeq2SeqLMConfig
 
 
  logger = get_logger(__name__)
@@ -38,8 +39,8 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
      mandatory_members = ["main_input_name"]
 
      def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-         _ = super().forward(*args, **kwargs)
-         return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
+         output = super().forward(*args, **kwargs)
+         return BaseModelOutput(last_hidden_state=output)
 
 
  class RBLNRuntimeDecoder(RBLNPytorchRuntime):
@@ -50,7 +51,6 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
          runtime: rebel.Runtime,
          batch_size: int,
          dec_max_seq_len: int,
-         support_paged_causal_attn: Optional[bool] = None,
          use_attention_mask: Optional[bool] = None,
          **kwargs: Any,
      ) -> None:
@@ -58,10 +58,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
          self.batch_size = batch_size
          self.dec_max_seq_len = dec_max_seq_len
          self.use_attention_mask = use_attention_mask
-         if support_paged_causal_attn:
-             self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
-         else:
-             self.default_block_tables = None
+         self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
 
      def forward(
          self,
@@ -119,12 +116,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
 
      main_input_name = "input_ids"
      auto_model_class = AutoModelForSeq2SeqLM
-     support_paged_causal_attn = None
+     support_causal_attn = None
 
      def __post_init__(self, **kwargs):
-         batch_size = self.rbln_config.model_cfg["batch_size"]
-         dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
-         self.use_attention_mask = self.rbln_config.model_cfg.get("use_attention_mask", None)
+         batch_size = self.rbln_config.batch_size
+         dec_max_seq_len = self.rbln_config.dec_max_seq_len
+         self.use_attention_mask = self.rbln_config.use_attention_mask
 
          self.encoder = RBLNRuntimeEncoder(
              runtime=self.model[0],
@@ -135,13 +132,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
              main_input_name="input_ids",
              batch_size=batch_size,
              dec_max_seq_len=dec_max_seq_len,
-             support_paged_causal_attn=self.support_paged_causal_attn,
              use_attention_mask=self.use_attention_mask,
          )
 
      @classmethod
      @torch.inference_mode()
-     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNConfig):
+     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
          wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
          enc_compile_config = rbln_config.compile_cfgs[0]
@@ -182,26 +178,15 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
          return {"encoder": compiled_encoder, "decoder": compiled_decoder}
 
      @classmethod
-     def _get_rbln_config(
+     def _update_rbln_config(
          cls,
          preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-         model_config: "PretrainedConfig",
-         rbln_kwargs: Dict[str, Any] = {},
-     ) -> RBLNConfig:
-         rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
-         rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
-         rbln_batch_size = rbln_kwargs.get("batch_size", None)
-         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-         if cls.support_paged_causal_attn:
-             rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
-             if rbln_use_attention_mask is None:
-                 rbln_use_attention_mask = False
-                 rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
-                 if rbln_npu == "RBLN-CA02":
-                     rbln_use_attention_mask = True
-         else:
-             rbln_use_attention_mask = True
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNModelForSeq2SeqLMConfig] = None,
+     ) -> RBLNModelForSeq2SeqLMConfig:
+         if not cls.support_causal_attn:
+             rbln_config.use_attention_mask = True
 
          n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
          n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
@@ -215,79 +200,85 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
              model_config, "max_position_embeddings", None
          )
 
-         rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
-         if rbln_pad_token_id is None:
-             rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
-         if rbln_pad_token_id is None:
-             rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
-         if rbln_pad_token_id is None:
-             rbln_pad_token_id = -1
-
-         if rbln_enc_max_seq_len is None:
-             rbln_enc_max_seq_len = max_position_embeddings
-             if rbln_enc_max_seq_len is None:
-                 for tokenizer in preprocessors:
-                     if hasattr(tokenizer, "model_max_length"):
-                         rbln_enc_max_seq_len = tokenizer.model_max_length
-                         break
-                 if rbln_enc_max_seq_len is None:
-                     raise ValueError("`rbln_enc_max_seq_len` should be specified!")
-         if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
-             raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
-
-         if rbln_dec_max_seq_len is None:
-             rbln_dec_max_seq_len = max_position_embeddings
-             if rbln_dec_max_seq_len is None:
-                 for tokenizer in preprocessors:
-                     if hasattr(tokenizer, "model_max_length"):
-                         rbln_dec_max_seq_len = tokenizer.model_max_length
-                         break
-                 if rbln_dec_max_seq_len is None:
-                     raise ValueError("`rbln_dec_max_seq_len` should be specified!")
-
-         if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
-             raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
+         pad_token_id = getattr(model_config, "pad_token_id", None)
+         pad_token_id = pad_token_id or getattr(model_config, "bos_token_id", None)
+         pad_token_id = pad_token_id or getattr(model_config, "eos_token_id", None)
+         pad_token_id = pad_token_id or -1
+         rbln_config.pad_token_id = pad_token_id
+
+         if rbln_config.enc_max_seq_len is None:
+             enc_max_seq_len = max_position_embeddings
+             for tokenizer in preprocessors:
+                 if hasattr(tokenizer, "model_max_length"):
+                     enc_max_seq_len = enc_max_seq_len or tokenizer.model_max_length
+                     break
+
+             if enc_max_seq_len is None:
+                 raise ValueError("`enc_max_seq_len` should be specified!")
+             rbln_config.enc_max_seq_len = enc_max_seq_len
+
+         if max_position_embeddings is not None and rbln_config.enc_max_seq_len > max_position_embeddings:
+             raise ValueError("`enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+         if rbln_config.dec_max_seq_len is None:
+             dec_max_seq_len = max_position_embeddings
+             for tokenizer in preprocessors:
+                 if hasattr(tokenizer, "model_max_length"):
+                     dec_max_seq_len = dec_max_seq_len or tokenizer.model_max_length
+                     break
+
+             if dec_max_seq_len is None:
+                 raise ValueError("`dec_max_seq_len` should be specified!")
+             rbln_config.dec_max_seq_len = dec_max_seq_len
+
+         if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
+             raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")
 
          # model input info
          enc_input_info = [
-             ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
-             ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
-             (
-                 "cross_key_value_states",
-                 [
-                     n_layer * 2,
-                     rbln_batch_size,
-                     n_head,
-                     rbln_enc_max_seq_len,
-                     d_kv,
-                 ],
-                 "float32",
-             ),
+             ("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
+             ("attention_mask", [1, rbln_config.enc_max_seq_len], "float32"),
              ("block_tables", [1], "int16"),
          ]
+         enc_input_info.extend(
+             [
+                 (
+                     f"cross_key_value_states_{i}",
+                     [
+                         rbln_config.batch_size,
+                         n_head,
+                         rbln_config.enc_max_seq_len,
+                         d_kv,
+                     ],
+                     "float32",
+                 )
+                 for i in range(n_layer * 2)
+             ]
+         )
 
          dec_input_info = [
-             ("input_ids", [rbln_batch_size, 1], "int64"),
-             ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
+             ("input_ids", [rbln_config.batch_size, 1], "int64"),
+             ("encoder_attention_mask", [rbln_config.batch_size, rbln_config.enc_max_seq_len], "float32"),
              (
                  "cache_position",
-                 [rbln_batch_size, 1],
+                 [rbln_config.batch_size, 1],
                  "int32",
              ),
+             ("block_tables", [rbln_config.batch_size, 1], "int16"),
          ]
          dec_input_info.extend(
              [
                  (
-                     "cross_key_value_states",
+                     f"cross_key_value_states_{i}",
                      [
-                         n_layer * 2,
-                         rbln_batch_size,
+                         rbln_config.batch_size,
                          n_head,
-                         rbln_enc_max_seq_len,
+                         rbln_config.enc_max_seq_len,
                          d_kv,
                      ],
                      "float32",
                  )
+                 for i in range(n_layer * 2)
              ]
          )
          dec_input_info.extend(
@@ -295,9 +286,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                  (
                      f"self_key_value_states_{i}",
                      [
-                         rbln_batch_size,
+                         rbln_config.batch_size,
                          n_head,
-                         rbln_dec_max_seq_len,
+                         rbln_config.dec_max_seq_len,
                          d_kv,
                      ],
                      "float32",
@@ -306,48 +297,38 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
              ]
          )
 
-         if cls.support_paged_causal_attn:
-             dec_input_info.insert(3, ("block_tables", [rbln_batch_size, 1], "int16"))
-         if rbln_use_attention_mask:
-             dec_input_info.insert(1, ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
+         if rbln_config.use_attention_mask:
+             dec_input_info.insert(
+                 1, ("attention_mask", [rbln_config.batch_size, rbln_config.dec_max_seq_len], "float32")
+             )
 
          enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
          dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
-         rbln_config = RBLNConfig(
-             rbln_cls=cls.__name__,
-             compile_cfgs=[enc_compile_config, dec_compile_config],
-             rbln_kwargs=rbln_kwargs,
-         )
-
-         rbln_config.model_cfg.update(
-             {
-                 "enc_max_seq_len": rbln_enc_max_seq_len,
-                 "dec_max_seq_len": rbln_dec_max_seq_len,
-                 "batch_size": rbln_batch_size,
-                 "pad_token_id": rbln_pad_token_id,
-                 "use_attention_mask": rbln_use_attention_mask,
-             }
-         )
-
+         rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
          return rbln_config
 
      @classmethod
      def _create_runtimes(
          cls,
          compiled_models: List[rebel.RBLNCompiledModel],
-         rbln_device_map: Dict[str, int],
-         activate_profiler: Optional[bool] = None,
+         rbln_config: RBLNModelForSeq2SeqLMConfig,
      ) -> List[rebel.Runtime]:
-         if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
+         if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
              cls._raise_missing_compiled_file_error(["encoder", "decoder"])
 
          return [
-             compiled_models[0].create_runtime(
-                 tensor_type="pt", device=rbln_device_map["encoder"], activate_profiler=activate_profiler
+             rebel.Runtime(
+                 compiled_models[0],
+                 tensor_type="pt",
+                 device=rbln_config.device_map["encoder"],
+                 activate_profiler=rbln_config.activate_profiler,
              ),
-             compiled_models[1].create_runtime(
-                 tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+             rebel.Runtime(
+                 compiled_models[1],
+                 tensor_type="pt",
+                 device=rbln_config.device_map["decoder"],
+                 activate_profiler=rbln_config.activate_profiler,
              ),
          ]
 
@@ -369,7 +350,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
      ):
          cur_seq_len = input_ids.shape[-1]
          cache_position = cur_seq_len - 1
-         max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+         max_seq_len = self.rbln_config.dec_max_seq_len
          decoder_batch_size = input_ids.shape[0]
          input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
          decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
@@ -389,7 +370,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
          **kwargs,
      ) -> Tuple[torch.FloatTensor]:
          # common decoder
-         cache_position = torch.full((self.rbln_config.model_cfg["batch_size"], 1), cache_position, dtype=torch.int32)
+         cache_position = torch.full((self.rbln_config.batch_size, 1), cache_position, dtype=torch.int32)
          logits = self.decoder(decoder_input_ids=decoder_input_ids, cache_position=cache_position, **kwargs).logits
 
          return Seq2SeqLMOutput(
@@ -423,11 +404,11 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
          batch_size, input_len = inputs_tensor.shape
          inputs_tensor = torch.nn.functional.pad(
              inputs_tensor,
-             (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len),
-             value=self.rbln_config.model_cfg["pad_token_id"],
+             (0, self.rbln_config.enc_max_seq_len - input_len),
+             value=self.rbln_config.pad_token_id,
          )
          model_kwargs["attention_mask"] = torch.nn.functional.pad(
-             model_kwargs["attention_mask"], (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len)
+             model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
          )
 
          # 3. make sure that encoder returns `ModelOutput`
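
Note: one visible consequence of the `_update_rbln_config` hunks above is that the single stacked `cross_key_value_states` input (shaped `[n_layer * 2, batch, heads, enc_len, d_kv]`) is split into `n_layer * 2` separately named inputs. A plain-Python illustration with toy sizes, mirroring the names in the diff (no RBLN imports required):

n_layer, n_head, enc_max_seq_len, d_kv, batch_size = 2, 8, 512, 64, 1

cross_kv_inputs = [
    (
        f"cross_key_value_states_{i}",  # one entry per key/value cache per decoder layer
        [batch_size, n_head, enc_max_seq_len, d_kv],
        "float32",
    )
    for i in range(n_layer * 2)
]
print([name for name, _, _ in cross_kv_inputs])
# ['cross_key_value_states_0', 'cross_key_value_states_1',
#  'cross_key_value_states_2', 'cross_key_value_states_3']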