optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +164 -36
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +772 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +63 -122
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +55 -70
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- optimum/rbln/modeling.py +58 -39
- optimum/rbln/modeling_base.py +107 -78
- optimum/rbln/transformers/__init__.py +87 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +108 -34
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +115 -84
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +282 -216
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/runtime_utils.py +33 -2
- optimum/rbln/utils/submodule.py +26 -43
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/METADATA +1 -1
- optimum_rbln-0.7.4a6.dist-info/RECORD +166 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,214 @@
|
|
1
|
+
import math
|
2
|
+
from typing import Tuple
|
3
|
+
|
4
|
+
import torch
|
5
|
+
import torch.nn as nn
|
6
|
+
|
7
|
+
from ..decoderonly.decoderonly_architecture import (
|
8
|
+
DecoderOnlyWrapper,
|
9
|
+
apply_rotary_pos_emb,
|
10
|
+
)
|
11
|
+
|
12
|
+
|
13
|
+
class Qwen2_5_VisionTransformerWrapper(nn.Module):
    """Wrap a vision transformer so every block receives an explicit additive mask.

    The wrapped model must expose ``fullatt_block_indexes``, ``merger``,
    ``window_size``, ``patch_size`` and ``blocks``. Each original block is
    replaced by a ``Qwen2_5_VLVisionBlock`` that uses full attention when its
    index is listed in ``fullatt_block_indexes`` and window attention otherwise.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self._original_mod = model
        self.fullatt_block_indexes = model.fullatt_block_indexes
        self.merger = model.merger
        # Window sequence length is the squared ratio of window size to patch size.
        patches_per_side = model.window_size // model.patch_size
        self.blocks = self.wrap_vision_blocks(model.blocks, patches_per_side**2)

    def wrap_vision_blocks(self, blocks: torch.nn.ModuleList, window_seq_len: int):
        """Return a ``ModuleList`` of wrapped blocks, one per entry of ``blocks``.

        Args:
            blocks: The original vision blocks, in execution order.
            window_seq_len: Sequence length of one window, forwarded to
                the window-attention blocks.
        """
        wrapped = [
            Qwen2_5_VLVisionBlock(blk, idx in self.fullatt_block_indexes, window_seq_len)
            for idx, blk in enumerate(blocks)
        ]
        return nn.ModuleList(wrapped)

    def forward(
        self,
        hidden_states: torch.Tensor,
        full_attn_masks: torch.Tensor,
        window_attn_masks: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        """Run all blocks and the merger over ``hidden_states``.

        Both mask arguments are converted to additive form: entries equal to 1
        become 0 and entries equal to 0 become the smallest float32 value.

        Args:
            hidden_states: Input features handed to the first block.
            full_attn_masks: Mask used by blocks listed in ``fullatt_block_indexes``.
            window_attn_masks: Mask used by every other block.
            cos: Cosine part of the rotary position embedding.
            sin: Sine part of the rotary position embedding.

        Returns:
            The output of ``merger`` applied to the last block's hidden states.
        """
        lowest = torch.finfo(torch.float32).min
        full_attn_masks = (1 - full_attn_masks) * lowest
        window_attn_masks = (1 - window_attn_masks) * lowest

        for idx, blk in enumerate(self.blocks):
            if idx in self.fullatt_block_indexes:
                selected_mask = full_attn_masks
            else:
                selected_mask = window_attn_masks
            hidden_states = blk(hidden_states, selected_mask, [cos, sin])

        return self.merger(hidden_states)
|
47
|
+
|
48
|
+
|
49
|
+
class Qwen2_5_VLVisionBlock(torch.nn.Module):
    """One vision transformer block with a pre-norm residual attention and MLP.

    Reuses ``norm1``, ``norm2`` and ``mlp`` from the wrapped block and replaces
    its attention with either the full or the window attention wrapper.
    """

    def __init__(self, model: torch.nn.Module, is_full_attn: bool, window_seq_len: int):
        """
        Args:
            model: Original block exposing ``norm1``, ``norm2``, ``attn`` and ``mlp``.
            is_full_attn: Use full attention when true, window attention otherwise.
            window_seq_len: Sequence length of one window (window attention only).
        """
        super().__init__()
        self._origin_model = model
        self.norm1 = model.norm1
        self.norm2 = model.norm2

        self.attn = (
            Qwen2_5_VLVisionFullAttention(model.attn)
            if is_full_attn
            else Qwen2_5_VLVisionWindowAttention(model.attn, window_seq_len)
        )
        self.mlp = model.mlp

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_masks: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each with a residual connection.

        Args:
            hidden_states: Input features of the block.
            attn_masks: Mask forwarded unchanged to the attention module.
            position_embeddings: ``(cos, sin)`` pair forwarded to the attention module.

        Returns:
            The hidden states after both residual updates.
        """
        attn_branch = self.attn(self.norm1(hidden_states), attn_masks, position_embeddings)
        hidden_states = hidden_states + attn_branch
        mlp_branch = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_branch
|
75
|
+
|
76
|
+
|
77
|
+
class Qwen2_5_VLVisionFullAttention(nn.Module):
    """Attention over the whole sequence, computed with explicit matmul and softmax.

    Reuses the ``qkv`` and ``proj`` layers and the ``num_heads`` / ``head_dim``
    values of the wrapped attention module.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self._origin_model = model
        self.num_heads = model.num_heads
        self.head_dim = model.head_dim
        self.qkv = model.qkv
        self.proj = model.proj

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_masks: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Attend across the full sequence.

        Args:
            hidden_states: Features whose first dimension is the sequence length.
            attn_masks: Additive mask summed with the scaled attention scores.
            position_embeddings: ``(cos, sin)`` pair for the rotary embedding.

        Returns:
            The projected attention output with the leading batch dimension removed.
        """
        seq_length = hidden_states.shape[0]
        # A batch dimension of 1 is added here and squeezed away again at the end.
        batched = hidden_states.unsqueeze(0)
        fused = self.qkv(batched).reshape(1, seq_length, 3, self.num_heads, -1)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        scores = scores + attn_masks
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)

        context = torch.matmul(probs, v).transpose(1, 2)
        context = context.reshape(1, seq_length, -1)
        return self.proj(context).squeeze(0)
|
110
|
+
|
111
|
+
|
112
|
+
class Qwen2_5_VLVisionWindowAttention(nn.Module):
    """Attention restricted to consecutive windows of ``window_seq_len`` positions.

    The sequence is cut into equally sized windows that are stacked along a new
    leading dimension, so each window is attended independently.
    """

    def __init__(self, model: nn.Module, window_seq_len: int) -> None:
        super().__init__()
        self._origin_model = model
        self.num_heads = model.num_heads
        self.head_dim = model.head_dim
        self.qkv = model.qkv
        self.proj = model.proj
        self.window_seq_len = window_seq_len

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_masks: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Attend within each window.

        Args:
            hidden_states: Features whose first dimension is the sequence length.
                NOTE(review): the length is assumed to be a multiple of
                ``window_seq_len``; otherwise the windows cannot be stacked.
            attn_masks: Additive mask summed with the scaled attention scores.
            position_embeddings: ``(cos, sin)`` pair, reshaped per window before
                the rotary embedding is applied.

        Returns:
            The projected attention output with the leading batch dimension removed.
        """
        seq_length = hidden_states.shape[0]
        win_len = self.window_seq_len
        num_windows = seq_length // win_len

        # Slice the sequence into consecutive windows and stack them.
        windows = [hidden_states[start : start + win_len] for start in range(0, seq_length, win_len)]
        hidden_states = torch.stack(windows)

        fused = self.qkv(hidden_states).reshape(num_windows, win_len, 3, self.num_heads, -1)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        cos, sin = position_embeddings
        cos = cos.reshape(num_windows, 1, seq_length // num_windows, -1)
        sin = sin.reshape(num_windows, 1, seq_length // num_windows, -1)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        scores = scores + attn_masks
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)

        # Put the windows back into one sequence before the output projection.
        context = torch.matmul(probs, v).transpose(1, 2)
        context = context.reshape(1, seq_length, -1)
        return self.proj(context).squeeze(0)
|
157
|
+
|
158
|
+
|
159
|
+
class Qwen2_5_VL_LanguageModelWrapper(DecoderOnlyWrapper):
    """Decoder wrapper whose positional inputs include an extra ``position_emb``."""

    def forward(self, *args):
        """Unpack positional inputs by phase and delegate to ``forward_common``.

        Expected order of ``args``; every remaining item is a past key/value state:

        * decode, with attention mask: inputs, cache_position, attention_mask,
          block_tables, position_emb
        * decode, without attention mask: inputs, cache_position, block_tables,
          position_emb
        * prefill, with attention mask: inputs, cache_position, attention_mask,
          query_position, block_tables, position_emb
        * prefill, without attention mask: inputs, cache_position,
          query_position, block_tables, position_emb

        Arguments that a layout does not carry are passed on as ``None``.

        Raises:
            ValueError: If ``self.phase`` is neither ``"decode"`` nor ``"prefill"``.
        """
        phase = self.phase
        if phase != "decode" and phase != "prefill":
            raise ValueError(f"Unknown phase: {self.phase}")

        if phase == "decode":
            # The decode layouts carry no query position.
            query_position = None
            if self.use_attention_mask:
                (
                    input_ids_or_inputs_embeds,
                    cache_position,
                    attention_mask,
                    block_tables,
                    position_emb,
                    *past_key_values,
                ) = args
            else:
                attention_mask = None
                (
                    input_ids_or_inputs_embeds,
                    cache_position,
                    block_tables,
                    position_emb,
                    *past_key_values,
                ) = args
        elif self.use_attention_mask:
            (
                input_ids_or_inputs_embeds,
                cache_position,
                attention_mask,
                query_position,
                block_tables,
                position_emb,
                *past_key_values,
            ) = args
        else:
            attention_mask = None
            (
                input_ids_or_inputs_embeds,
                cache_position,
                query_position,
                block_tables,
                position_emb,
                *past_key_values,
            ) = args

        return self.forward_common(
            input_ids_or_inputs_embeds,
            cache_position,
            attention_mask,
            query_position,
            block_tables,
            position_emb,
            *past_key_values,
        )
|
@@ -0,0 +1,66 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Optional
|
16
|
+
|
17
|
+
import rebel
|
18
|
+
|
19
|
+
from ....configuration_utils import RBLNModelConfig
|
20
|
+
from ....utils.logging import get_logger
|
21
|
+
|
22
|
+
|
23
|
+
logger = get_logger()
|
24
|
+
|
25
|
+
|
26
|
+
class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
    """Configuration holding the compile-time options of a seq2seq language model."""

    def __init__(
        self,
        batch_size: Optional[int] = None,
        enc_max_seq_len: Optional[int] = None,
        dec_max_seq_len: Optional[int] = None,
        use_attention_mask: Optional[bool] = None,
        pad_token_id: Optional[int] = None,
        **kwargs,
    ):
        """
        Args:
            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
            enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
            dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
            use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
                This is automatically set to True for RBLN-CA02 devices.
            pad_token_id (Optional[int]): The ID of the padding token in the vocabulary.
            **kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        super().__init__(**kwargs)

        # Only a missing value falls back to the default. An explicit 0 used to be
        # coerced to 1 by `batch_size or 1`; it is now rejected like any other
        # non-positive value, as the error message and docstring promise.
        self.batch_size = 1 if batch_size is None else batch_size
        if not isinstance(self.batch_size, int) or self.batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        self.enc_max_seq_len = enc_max_seq_len
        self.dec_max_seq_len = dec_max_seq_len

        self.use_attention_mask = use_attention_mask
        # Fall back to the detected device name when no NPU is set on the config.
        npu = self.npu or rebel.get_npu_name()
        if npu == "RBLN-CA02":
            # The mask is forced on for this device; warn only on an explicit False.
            if self.use_attention_mask is False:
                logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
            self.use_attention_mask = True
        else:
            self.use_attention_mask = self.use_attention_mask or False

        self.pad_token_id = pad_token_id
|
@@ -22,10 +22,11 @@ from rebel.compile_context import CompileContext
|
|
22
22
|
from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
|
23
23
|
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
24
24
|
|
25
|
+
from ....configuration_utils import RBLNCompileConfig
|
25
26
|
from ....modeling import RBLNModel
|
26
|
-
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
27
27
|
from ....utils.logging import get_logger
|
28
28
|
from ....utils.runtime_utils import RBLNPytorchRuntime
|
29
|
+
from .configuration_seq2seq2 import RBLNModelForSeq2SeqLMConfig
|
29
30
|
|
30
31
|
|
31
32
|
logger = get_logger(__name__)
|
@@ -118,9 +119,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
118
119
|
support_causal_attn = None
|
119
120
|
|
120
121
|
def __post_init__(self, **kwargs):
|
121
|
-
batch_size = self.rbln_config.
|
122
|
-
dec_max_seq_len = self.rbln_config.
|
123
|
-
self.use_attention_mask = self.rbln_config.
|
122
|
+
batch_size = self.rbln_config.batch_size
|
123
|
+
dec_max_seq_len = self.rbln_config.dec_max_seq_len
|
124
|
+
self.use_attention_mask = self.rbln_config.use_attention_mask
|
124
125
|
|
125
126
|
self.encoder = RBLNRuntimeEncoder(
|
126
127
|
runtime=self.model[0],
|
@@ -136,7 +137,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
136
137
|
|
137
138
|
@classmethod
|
138
139
|
@torch.inference_mode()
|
139
|
-
def get_compiled_model(cls, model: PreTrainedModel, rbln_config:
|
140
|
+
def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
|
140
141
|
wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
|
141
142
|
|
142
143
|
enc_compile_config = rbln_config.compile_cfgs[0]
|
@@ -177,26 +178,15 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
177
178
|
return {"encoder": compiled_encoder, "decoder": compiled_decoder}
|
178
179
|
|
179
180
|
@classmethod
|
180
|
-
def
|
181
|
+
def _update_rbln_config(
|
181
182
|
cls,
|
182
183
|
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
|
190
|
-
|
191
|
-
if cls.support_causal_attn:
|
192
|
-
rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
|
193
|
-
if rbln_use_attention_mask is None:
|
194
|
-
rbln_use_attention_mask = False
|
195
|
-
rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
|
196
|
-
if rbln_npu == "RBLN-CA02":
|
197
|
-
rbln_use_attention_mask = True
|
198
|
-
else:
|
199
|
-
rbln_use_attention_mask = True
|
184
|
+
model: Optional["PreTrainedModel"] = None,
|
185
|
+
model_config: Optional["PretrainedConfig"] = None,
|
186
|
+
rbln_config: Optional[RBLNModelForSeq2SeqLMConfig] = None,
|
187
|
+
) -> RBLNModelForSeq2SeqLMConfig:
|
188
|
+
if not cls.support_causal_attn:
|
189
|
+
rbln_config.use_attention_mask = True
|
200
190
|
|
201
191
|
n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
|
202
192
|
n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
|
@@ -210,43 +200,44 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
210
200
|
model_config, "max_position_embeddings", None
|
211
201
|
)
|
212
202
|
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
if max_position_embeddings is not None and
|
231
|
-
raise ValueError("`
|
232
|
-
|
233
|
-
if
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
203
|
+
pad_token_id = getattr(model_config, "pad_token_id", None)
|
204
|
+
pad_token_id = pad_token_id or getattr(model_config, "bos_token_id", None)
|
205
|
+
pad_token_id = pad_token_id or getattr(model_config, "eos_token_id", None)
|
206
|
+
pad_token_id = pad_token_id or -1
|
207
|
+
rbln_config.pad_token_id = pad_token_id
|
208
|
+
|
209
|
+
if rbln_config.enc_max_seq_len is None:
|
210
|
+
enc_max_seq_len = max_position_embeddings
|
211
|
+
for tokenizer in preprocessors:
|
212
|
+
if hasattr(tokenizer, "model_max_length"):
|
213
|
+
enc_max_seq_len = enc_max_seq_len or tokenizer.model_max_length
|
214
|
+
break
|
215
|
+
|
216
|
+
if enc_max_seq_len is None:
|
217
|
+
raise ValueError("`enc_max_seq_len` should be specified!")
|
218
|
+
rbln_config.enc_max_seq_len = enc_max_seq_len
|
219
|
+
|
220
|
+
if max_position_embeddings is not None and rbln_config.enc_max_seq_len > max_position_embeddings:
|
221
|
+
raise ValueError("`enc_max_seq_len` should be less or equal than max_position_embeddings!")
|
222
|
+
|
223
|
+
if rbln_config.dec_max_seq_len is None:
|
224
|
+
dec_max_seq_len = max_position_embeddings
|
225
|
+
for tokenizer in preprocessors:
|
226
|
+
if hasattr(tokenizer, "model_max_length"):
|
227
|
+
dec_max_seq_len = dec_max_seq_len or tokenizer.model_max_length
|
228
|
+
break
|
229
|
+
|
230
|
+
if dec_max_seq_len is None:
|
231
|
+
raise ValueError("`dec_max_seq_len` should be specified!")
|
232
|
+
rbln_config.dec_max_seq_len = dec_max_seq_len
|
233
|
+
|
234
|
+
if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
|
235
|
+
raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")
|
245
236
|
|
246
237
|
# model input info
|
247
238
|
enc_input_info = [
|
248
|
-
("input_ids", [1,
|
249
|
-
("attention_mask", [1,
|
239
|
+
("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
|
240
|
+
("attention_mask", [1, rbln_config.enc_max_seq_len], "float32"),
|
250
241
|
("block_tables", [1], "int16"),
|
251
242
|
]
|
252
243
|
enc_input_info.extend(
|
@@ -254,9 +245,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
254
245
|
(
|
255
246
|
f"cross_key_value_states_{i}",
|
256
247
|
[
|
257
|
-
|
248
|
+
rbln_config.batch_size,
|
258
249
|
n_head,
|
259
|
-
|
250
|
+
rbln_config.enc_max_seq_len,
|
260
251
|
d_kv,
|
261
252
|
],
|
262
253
|
"float32",
|
@@ -266,23 +257,23 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
266
257
|
)
|
267
258
|
|
268
259
|
dec_input_info = [
|
269
|
-
("input_ids", [
|
270
|
-
("encoder_attention_mask", [
|
260
|
+
("input_ids", [rbln_config.batch_size, 1], "int64"),
|
261
|
+
("encoder_attention_mask", [rbln_config.batch_size, rbln_config.enc_max_seq_len], "float32"),
|
271
262
|
(
|
272
263
|
"cache_position",
|
273
|
-
[
|
264
|
+
[rbln_config.batch_size, 1],
|
274
265
|
"int32",
|
275
266
|
),
|
276
|
-
("block_tables", [
|
267
|
+
("block_tables", [rbln_config.batch_size, 1], "int16"),
|
277
268
|
]
|
278
269
|
dec_input_info.extend(
|
279
270
|
[
|
280
271
|
(
|
281
272
|
f"cross_key_value_states_{i}",
|
282
273
|
[
|
283
|
-
|
274
|
+
rbln_config.batch_size,
|
284
275
|
n_head,
|
285
|
-
|
276
|
+
rbln_config.enc_max_seq_len,
|
286
277
|
d_kv,
|
287
278
|
],
|
288
279
|
"float32",
|
@@ -295,9 +286,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
295
286
|
(
|
296
287
|
f"self_key_value_states_{i}",
|
297
288
|
[
|
298
|
-
|
289
|
+
rbln_config.batch_size,
|
299
290
|
n_head,
|
300
|
-
|
291
|
+
rbln_config.dec_max_seq_len,
|
301
292
|
d_kv,
|
302
293
|
],
|
303
294
|
"float32",
|
@@ -306,46 +297,38 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
306
297
|
]
|
307
298
|
)
|
308
299
|
|
309
|
-
if
|
310
|
-
dec_input_info.insert(
|
300
|
+
if rbln_config.use_attention_mask:
|
301
|
+
dec_input_info.insert(
|
302
|
+
1, ("attention_mask", [rbln_config.batch_size, rbln_config.dec_max_seq_len], "float32")
|
303
|
+
)
|
311
304
|
|
312
305
|
enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
|
313
306
|
dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
|
314
307
|
|
315
|
-
rbln_config
|
316
|
-
rbln_cls=cls.__name__,
|
317
|
-
compile_cfgs=[enc_compile_config, dec_compile_config],
|
318
|
-
rbln_kwargs=rbln_kwargs,
|
319
|
-
)
|
320
|
-
|
321
|
-
rbln_config.model_cfg.update(
|
322
|
-
{
|
323
|
-
"enc_max_seq_len": rbln_enc_max_seq_len,
|
324
|
-
"dec_max_seq_len": rbln_dec_max_seq_len,
|
325
|
-
"batch_size": rbln_batch_size,
|
326
|
-
"pad_token_id": rbln_pad_token_id,
|
327
|
-
"use_attention_mask": rbln_use_attention_mask,
|
328
|
-
}
|
329
|
-
)
|
330
|
-
|
308
|
+
rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
|
331
309
|
return rbln_config
|
332
310
|
|
333
311
|
@classmethod
|
334
312
|
def _create_runtimes(
|
335
313
|
cls,
|
336
314
|
compiled_models: List[rebel.RBLNCompiledModel],
|
337
|
-
|
338
|
-
activate_profiler: Optional[bool] = None,
|
315
|
+
rbln_config: RBLNModelForSeq2SeqLMConfig,
|
339
316
|
) -> List[rebel.Runtime]:
|
340
|
-
if any(model_name not in
|
317
|
+
if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
|
341
318
|
cls._raise_missing_compiled_file_error(["encoder", "decoder"])
|
342
319
|
|
343
320
|
return [
|
344
|
-
|
345
|
-
|
321
|
+
rebel.Runtime(
|
322
|
+
compiled_models[0],
|
323
|
+
tensor_type="pt",
|
324
|
+
device=rbln_config.device_map["encoder"],
|
325
|
+
activate_profiler=rbln_config.activate_profiler,
|
346
326
|
),
|
347
|
-
|
348
|
-
|
327
|
+
rebel.Runtime(
|
328
|
+
compiled_models[1],
|
329
|
+
tensor_type="pt",
|
330
|
+
device=rbln_config.device_map["decoder"],
|
331
|
+
activate_profiler=rbln_config.activate_profiler,
|
349
332
|
),
|
350
333
|
]
|
351
334
|
|
@@ -367,7 +350,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
367
350
|
):
|
368
351
|
cur_seq_len = input_ids.shape[-1]
|
369
352
|
cache_position = cur_seq_len - 1
|
370
|
-
max_seq_len = self.rbln_config.
|
353
|
+
max_seq_len = self.rbln_config.dec_max_seq_len
|
371
354
|
decoder_batch_size = input_ids.shape[0]
|
372
355
|
input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
|
373
356
|
decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
|
@@ -387,7 +370,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
387
370
|
**kwargs,
|
388
371
|
) -> Tuple[torch.FloatTensor]:
|
389
372
|
# common decoder
|
390
|
-
cache_position = torch.full((self.rbln_config.
|
373
|
+
cache_position = torch.full((self.rbln_config.batch_size, 1), cache_position, dtype=torch.int32)
|
391
374
|
logits = self.decoder(decoder_input_ids=decoder_input_ids, cache_position=cache_position, **kwargs).logits
|
392
375
|
|
393
376
|
return Seq2SeqLMOutput(
|
@@ -421,11 +404,11 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
421
404
|
batch_size, input_len = inputs_tensor.shape
|
422
405
|
inputs_tensor = torch.nn.functional.pad(
|
423
406
|
inputs_tensor,
|
424
|
-
(0, self.rbln_config.
|
425
|
-
value=self.rbln_config.
|
407
|
+
(0, self.rbln_config.enc_max_seq_len - input_len),
|
408
|
+
value=self.rbln_config.pad_token_id,
|
426
409
|
)
|
427
410
|
model_kwargs["attention_mask"] = torch.nn.functional.pad(
|
428
|
-
model_kwargs["attention_mask"], (0, self.rbln_config.
|
411
|
+
model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
|
429
412
|
)
|
430
413
|
|
431
414
|
# 3. make sure that encoder returns `ModelOutput`
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
|
16
|
+
from ..seq2seq import RBLNModelForSeq2SeqLMConfig
|
17
|
+
|
18
|
+
|
19
|
+
class RBLNT5EncoderModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
    """T5 encoder configuration.

    Inherits every option from RBLNTransformerEncoderForFeatureExtractionConfig
    and adds none of its own.
    """
|
21
|
+
|
22
|
+
|
23
|
+
class RBLNT5ForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
    """T5 conditional-generation configuration.

    Inherits every option from RBLNModelForSeq2SeqLMConfig and adds none of
    its own.
    """
|