optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +173 -35
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +816 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +111 -137
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +56 -71
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
- optimum/rbln/modeling.py +66 -40
- optimum/rbln/modeling_base.py +111 -86
- optimum/rbln/ops/__init__.py +4 -7
- optimum/rbln/ops/attn.py +271 -205
- optimum/rbln/ops/flash_attn.py +161 -67
- optimum/rbln/ops/kv_cache_update.py +4 -40
- optimum/rbln/ops/linear.py +25 -0
- optimum/rbln/transformers/__init__.py +97 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +120 -32
- optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
- optimum/rbln/transformers/models/bart/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
- optimum/rbln/transformers/models/t5/__init__.py +2 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
- optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
- optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +2 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/hub.py +2 -2
- optimum/rbln/utils/import_utils.py +23 -6
- optimum/rbln/utils/model_utils.py +4 -4
- optimum/rbln/utils/runtime_utils.py +33 -2
- optimum/rbln/utils/submodule.py +36 -44
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
- optimum_rbln-0.7.4.dist-info/RECORD +169 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
--- /dev/null
+++ b/optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py
@@ -0,0 +1,214 @@
+import math
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyWrapper,
+    apply_rotary_pos_emb,
+)
+
+
+class Qwen2_5_VisionTransformerWrapper(nn.Module):
+    def __init__(self, model: torch.nn.Module):
+        super().__init__()
+        self._original_mod = model
+        self.fullatt_block_indexes = model.fullatt_block_indexes
+        self.merger = model.merger
+        window_seq_len = (model.window_size // model.patch_size) ** 2
+        self.blocks = self.wrap_vision_blocks(model.blocks, window_seq_len)
+
+    def wrap_vision_blocks(self, blocks: torch.nn.ModuleList, window_seq_len: int):
+        wrapped_blocks = []
+        for i, block in enumerate(blocks):
+            is_full_attn = True if i in self.fullatt_block_indexes else False
+            wrapped_blocks.append(Qwen2_5_VLVisionBlock(block, is_full_attn, window_seq_len))
+        return nn.ModuleList(wrapped_blocks)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        full_attn_masks: torch.Tensor,
+        window_attn_masks: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+    ):
+        full_attn_masks = (1 - full_attn_masks) * torch.finfo(torch.float32).min
+        window_attn_masks = (1 - window_attn_masks) * torch.finfo(torch.float32).min
+
+        for i, block in enumerate(self.blocks):
+            attn_masks = full_attn_masks if i in self.fullatt_block_indexes else window_attn_masks
+            hidden_states = block(hidden_states, attn_masks, [cos, sin])
+
+        hidden_states = self.merger(hidden_states)
+
+        return hidden_states
+
+
+class Qwen2_5_VLVisionBlock(torch.nn.Module):
+    def __init__(self, model: torch.nn.Module, is_full_attn: bool, window_seq_len: int):
+        super().__init__()
+        self._origin_model = model
+        self.norm1 = model.norm1
+        self.norm2 = model.norm2
+
+        if is_full_attn:
+            self.attn = Qwen2_5_VLVisionFullAttention(model.attn)
+        else:
+            self.attn = Qwen2_5_VLVisionWindowAttention(model.attn, window_seq_len)
+        self.mlp = model.mlp
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_masks: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+    ) -> torch.Tensor:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            attn_masks,
+            position_embeddings,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class Qwen2_5_VLVisionFullAttention(nn.Module):
+    def __init__(self, model: nn.Module) -> None:
+        super().__init__()
+        self._origin_model = model
+        self.num_heads = model.num_heads
+        self.head_dim = model.head_dim
+        self.qkv = model.qkv
+        self.proj = model.proj
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_masks: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        hidden_states = hidden_states.unsqueeze(0)
+        q, k, v = (
+            self.qkv(hidden_states).reshape(1, seq_length, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
+        )
+
+        cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attn_masks
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+        attn_output = torch.matmul(attn_weights, v)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(1, seq_length, -1)
+        attn_output = self.proj(attn_output).squeeze(0)
+
+        return attn_output
+
+
+class Qwen2_5_VLVisionWindowAttention(nn.Module):
+    def __init__(self, model: nn.Module, window_seq_len: int) -> None:
+        super().__init__()
+        self._origin_model = model
+        self.num_heads = model.num_heads
+        self.head_dim = model.head_dim
+        self.qkv = model.qkv
+        self.proj = model.proj
+        self.window_seq_len = window_seq_len
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_masks: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        num_windows = seq_length // self.window_seq_len
+
+        window_hidden_states = []
+        for i in range(0, seq_length, self.window_seq_len):
+            window_hidden_states.append(hidden_states[i : i + self.window_seq_len])
+        hidden_states = torch.stack(window_hidden_states)
+
+        q, k, v = (
+            self.qkv(hidden_states)
+            .reshape(num_windows, self.window_seq_len, 3, self.num_heads, -1)
+            .permute(2, 0, 3, 1, 4)
+            .unbind(0)
+        )
+        cos, sin = position_embeddings
+        cos = cos.reshape(num_windows, 1, seq_length // num_windows, -1)
+        sin = sin.reshape(num_windows, 1, seq_length // num_windows, -1)
+        q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        attn_weights = attn_weights + attn_masks
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+        attn_output = torch.matmul(attn_weights, v)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(1, seq_length, -1)
+        attn_output = self.proj(attn_output).squeeze(0)
+
+        return attn_output
+
+
+class Qwen2_5_VL_LanguageModelWrapper(DecoderOnlyWrapper):
+    def forward(self, *args):
+        if self.phase == "decode":
+            if self.use_attention_mask:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    attention_mask,
+                    block_tables,
+                    position_emb,
+                    *past_key_values,
+                ) = args
+            else:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    block_tables,
+                    position_emb,
+                    *past_key_values,
+                ) = args
+                attention_mask = None
+            query_position = None
+        elif self.phase == "prefill":
+            if self.use_attention_mask:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    attention_mask,
+                    query_position,
+                    block_tables,
+                    position_emb,
+                    *past_key_values,
+                ) = args
+            else:
+                (
+                    input_ids_or_inputs_embeds,
+                    cache_position,
+                    query_position,
+                    block_tables,
+                    position_emb,
+                    *past_key_values,
+                ) = args
+                attention_mask = None
+
+        else:
+            raise ValueError(f"Unknown phase: {self.phase}")
+
+        return self.forward_common(
+            input_ids_or_inputs_embeds,
+            cache_position,
+            attention_mask,
+            query_position,
+            block_tables,
+            position_emb,
+            *past_key_values,
+        )
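The window attention added above regroups a flat patch sequence into fixed-size windows before computing attention. A minimal standalone sketch of that regrouping step (dummy shapes only, not library code):

```python
import torch

# Illustrative sizes; in the wrapper, window_seq_len = (window_size // patch_size) ** 2.
seq_len, window_seq_len, hidden_dim = 16, 4, 32
num_windows = seq_len // window_seq_len

hidden = torch.randn(seq_len, hidden_dim)
# Mirrors the stacking loop in Qwen2_5_VLVisionWindowAttention.forward:
# (seq_len, hidden) -> (num_windows, window_seq_len, hidden)
windows = torch.stack([hidden[i : i + window_seq_len] for i in range(0, seq_len, window_seq_len)])
print(windows.shape)  # torch.Size([4, 4, 32])
```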
--- /dev/null
+++ b/optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py
@@ -0,0 +1,66 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import rebel
+
+from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger()
+
+
+class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        enc_max_seq_len: Optional[int] = None,
+        dec_max_seq_len: Optional[int] = None,
+        use_attention_mask: Optional[bool] = None,
+        pad_token_id: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
+            dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
+            use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
+                This is automatically set to True for RBLN-CA02 devices.
+            pad_token_id (Optional[int]): The ID of the padding token in the vocabulary.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.enc_max_seq_len = enc_max_seq_len
+        self.dec_max_seq_len = dec_max_seq_len
+
+        self.use_attention_mask = use_attention_mask
+        npu = self.npu or rebel.get_npu_name()
+        if npu == "RBLN-CA02":
+            if self.use_attention_mask is False:
+                logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
+            self.use_attention_mask = True
+        else:
+            self.use_attention_mask = self.use_attention_mask or False
+
+        self.pad_token_id = pad_token_id
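A hedged usage sketch of the new config class: the constructor arguments are taken verbatim from the hunk above, while the import path is an assumption drawn from the file list, and constructing the object requires the rebel SDK (for the NPU check in `__init__`).

```python
# Assumed import path (from the file list); not verified against the package's public exports.
from optimum.rbln.transformers.models.seq2seq.configuration_seq2seq2 import (
    RBLNModelForSeq2SeqLMConfig,
)

# Arguments mirror the __init__ signature shown in the diff; values are illustrative.
cfg = RBLNModelForSeq2SeqLMConfig(
    batch_size=1,
    enc_max_seq_len=512,
    dec_max_seq_len=512,
    use_attention_mask=False,  # forced to True on RBLN-CA02 NPUs by __init__
)
print(cfg.batch_size, cfg.use_attention_mask)
```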
--- a/optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
+++ b/optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -22,10 +22,11 @@ from rebel.compile_context import CompileContext
 from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from .configuration_seq2seq2 import RBLNModelForSeq2SeqLMConfig


 logger = get_logger(__name__)
@@ -38,8 +39,8 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]

     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-
-        return BaseModelOutput(last_hidden_state=
+        output = super().forward(*args, **kwargs)
+        return BaseModelOutput(last_hidden_state=output)


 class RBLNRuntimeDecoder(RBLNPytorchRuntime):
@@ -94,7 +95,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
             decoder_attention_mask if self.use_attention_mask else None,
             attention_mask,
             cache_position,
-            block_tables,
+            block_tables=block_tables,
         )

         return Seq2SeqLMOutput(logits=lm_logits)
@@ -115,11 +116,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):

     main_input_name = "input_ids"
     auto_model_class = AutoModelForSeq2SeqLM
+    support_causal_attn = None

     def __post_init__(self, **kwargs):
-        batch_size = self.rbln_config.
-        dec_max_seq_len = self.rbln_config.
-        self.use_attention_mask = self.rbln_config.
+        batch_size = self.rbln_config.batch_size
+        dec_max_seq_len = self.rbln_config.dec_max_seq_len
+        self.use_attention_mask = self.rbln_config.use_attention_mask

         self.encoder = RBLNRuntimeEncoder(
             runtime=self.model[0],
@@ -135,7 +137,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):

     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(cls, model: PreTrainedModel, rbln_config:
+    def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

         enc_compile_config = rbln_config.compile_cfgs[0]
@@ -176,23 +178,15 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         return {"encoder": compiled_encoder, "decoder": compiled_decoder}

     @classmethod
-    def
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-
-
-
-
-
-
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
-
-        if rbln_use_attention_mask is None:
-            rbln_use_attention_mask = False
-        rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
-        if rbln_npu == "RBLN-CA02":
-            rbln_use_attention_mask = True
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNModelForSeq2SeqLMConfig] = None,
+    ) -> RBLNModelForSeq2SeqLMConfig:
+        if not cls.support_causal_attn:
+            rbln_config.use_attention_mask = True

         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
@@ -206,84 +200,85 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
             model_config, "max_position_embeddings", None
         )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if max_position_embeddings is not None and
-            raise ValueError("`
-
-        if
-
-
-
-
-
-
-
-
-
-
-
+        pad_token_id = getattr(model_config, "pad_token_id", None)
+        pad_token_id = pad_token_id or getattr(model_config, "bos_token_id", None)
+        pad_token_id = pad_token_id or getattr(model_config, "eos_token_id", None)
+        pad_token_id = pad_token_id or -1
+        rbln_config.pad_token_id = pad_token_id
+
+        if rbln_config.enc_max_seq_len is None:
+            enc_max_seq_len = max_position_embeddings
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_max_length"):
+                    enc_max_seq_len = enc_max_seq_len or tokenizer.model_max_length
+                    break
+
+            if enc_max_seq_len is None:
+                raise ValueError("`enc_max_seq_len` should be specified!")
+            rbln_config.enc_max_seq_len = enc_max_seq_len
+
+        if max_position_embeddings is not None and rbln_config.enc_max_seq_len > max_position_embeddings:
+            raise ValueError("`enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_config.dec_max_seq_len is None:
+            dec_max_seq_len = max_position_embeddings
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_max_length"):
+                    dec_max_seq_len = dec_max_seq_len or tokenizer.model_max_length
+                    break
+
+            if dec_max_seq_len is None:
+                raise ValueError("`dec_max_seq_len` should be specified!")
+            rbln_config.dec_max_seq_len = dec_max_seq_len
+
+        if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
+            raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")

         # model input info
         enc_input_info = [
-            ("input_ids", [1,
-            ("attention_mask", [1,
-            (
-                "cross_key_value_states",
-                [
-                    n_layer * 2,
-                    rbln_batch_size,
-                    n_head,
-                    rbln_enc_max_seq_len,
-                    d_kv,
-                ],
-                "float32",
-            ),
+            ("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
+            ("attention_mask", [1, rbln_config.enc_max_seq_len], "float32"),
             ("block_tables", [1], "int16"),
         ]
+        enc_input_info.extend(
+            [
+                (
+                    f"cross_key_value_states_{i}",
+                    [
+                        rbln_config.batch_size,
+                        n_head,
+                        rbln_config.enc_max_seq_len,
+                        d_kv,
+                    ],
+                    "float32",
+                )
+                for i in range(n_layer * 2)
+            ]
+        )

         dec_input_info = [
-            ("input_ids", [
-            ("encoder_attention_mask", [
+            ("input_ids", [rbln_config.batch_size, 1], "int64"),
+            ("encoder_attention_mask", [rbln_config.batch_size, rbln_config.enc_max_seq_len], "float32"),
             (
                 "cache_position",
-                [
+                [rbln_config.batch_size, 1],
                 "int32",
             ),
-            (
-                "block_tables",
-                [rbln_batch_size, 1],
-                "int16",
-            ),
+            ("block_tables", [rbln_config.batch_size, 1], "int16"),
         ]
         dec_input_info.extend(
             [
                 (
-                    "
+                    f"cross_key_value_states_{i}",
                     [
-
-                        rbln_batch_size,
+                        rbln_config.batch_size,
                         n_head,
-
+                        rbln_config.enc_max_seq_len,
                         d_kv,
                     ],
                     "float32",
                 )
+                for i in range(n_layer * 2)
             ]
         )
         dec_input_info.extend(
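The hunk above also changes how the encoder's cross-attention KV cache is described to the compiler: the single stacked `cross_key_value_states` entry becomes one named entry per K/V tensor. A small illustrative sketch of the two layouts (sizes are placeholders, not taken from any real model config):

```python
# Placeholder sizes; in the real code they come from the HF model config.
n_layer, n_head, enc_max_seq_len, d_kv, batch_size = 2, 8, 512, 64, 1

# 0.7.3.post2 layout: one stacked tensor holding every layer's K and V cache.
old_entry = ("cross_key_value_states", [n_layer * 2, batch_size, n_head, enc_max_seq_len, d_kv], "float32")

# 0.7.4 layout: one entry per K/V cache, as built in _update_rbln_config above.
new_entries = [
    (f"cross_key_value_states_{i}", [batch_size, n_head, enc_max_seq_len, d_kv], "float32")
    for i in range(n_layer * 2)
]
print(len(new_entries))  # 4 entries for a 2-layer decoder (one K and one V per layer)
```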
@@ -291,9 +286,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 (
                     f"self_key_value_states_{i}",
                     [
-
+                        rbln_config.batch_size,
                         n_head,
-
+                        rbln_config.dec_max_seq_len,
                         d_kv,
                     ],
                     "float32",
@@ -302,46 +297,38 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
             ]
         )

-        if
-            dec_input_info.insert(
+        if rbln_config.use_attention_mask:
+            dec_input_info.insert(
+                1, ("attention_mask", [rbln_config.batch_size, rbln_config.dec_max_seq_len], "float32")
+            )

         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

-        rbln_config
-            rbln_cls=cls.__name__,
-            compile_cfgs=[enc_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "enc_max_seq_len": rbln_enc_max_seq_len,
-                "dec_max_seq_len": rbln_dec_max_seq_len,
-                "batch_size": rbln_batch_size,
-                "pad_token_id": rbln_pad_token_id,
-                "use_attention_mask": rbln_use_attention_mask,
-            }
-        )
-
+        rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
         return rbln_config

     @classmethod
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNModelForSeq2SeqLMConfig,
     ) -> List[rebel.Runtime]:
-        if any(model_name not in
+        if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
            cls._raise_missing_compiled_file_error(["encoder", "decoder"])

         return [
-
-
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["encoder"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
-
-
+            rebel.Runtime(
+                compiled_models[1],
+                tensor_type="pt",
+                device=rbln_config.device_map["decoder"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
         ]

@@ -363,7 +350,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     ):
         cur_seq_len = input_ids.shape[-1]
         cache_position = cur_seq_len - 1
-        max_seq_len = self.rbln_config.
+        max_seq_len = self.rbln_config.dec_max_seq_len
         decoder_batch_size = input_ids.shape[0]
         input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
         decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
@@ -383,7 +370,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # common decoder
-        cache_position = torch.full((self.rbln_config.
+        cache_position = torch.full((self.rbln_config.batch_size, 1), cache_position, dtype=torch.int32)
         logits = self.decoder(decoder_input_ids=decoder_input_ids, cache_position=cache_position, **kwargs).logits

         return Seq2SeqLMOutput(
@@ -417,11 +404,11 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         batch_size, input_len = inputs_tensor.shape
         inputs_tensor = torch.nn.functional.pad(
             inputs_tensor,
-            (0, self.rbln_config.
-            value=self.rbln_config.
+            (0, self.rbln_config.enc_max_seq_len - input_len),
+            value=self.rbln_config.pad_token_id,
         )
         model_kwargs["attention_mask"] = torch.nn.functional.pad(
-            model_kwargs["attention_mask"], (0, self.rbln_config.
+            model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
         )

         # 3. make sure that encoder returns `ModelOutput`