optimum-rbln 0.7.3.post1__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +173 -35
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +816 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +111 -137
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +56 -71
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
- optimum/rbln/modeling.py +66 -40
- optimum/rbln/modeling_base.py +111 -86
- optimum/rbln/ops/__init__.py +4 -7
- optimum/rbln/ops/attn.py +271 -205
- optimum/rbln/ops/flash_attn.py +161 -67
- optimum/rbln/ops/kv_cache_update.py +4 -40
- optimum/rbln/ops/linear.py +25 -0
- optimum/rbln/transformers/__init__.py +97 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +120 -32
- optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
- optimum/rbln/transformers/models/bart/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +11 -86
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -118
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
- optimum/rbln/transformers/models/t5/__init__.py +2 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +23 -151
- optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
- optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +2 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/hub.py +2 -2
- optimum/rbln/utils/import_utils.py +23 -6
- optimum/rbln/utils/model_utils.py +4 -4
- optimum/rbln/utils/runtime_utils.py +33 -2
- optimum/rbln/utils/submodule.py +36 -44
- {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
- optimum_rbln-0.7.4.dist-info/RECORD +169 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.3.post1.dist-info/RECORD +0 -122
- {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,214 @@
|
|
1
|
+
import math
|
2
|
+
from typing import Tuple
|
3
|
+
|
4
|
+
import torch
|
5
|
+
import torch.nn as nn
|
6
|
+
|
7
|
+
from ..decoderonly.decoderonly_architecture import (
|
8
|
+
DecoderOnlyWrapper,
|
9
|
+
apply_rotary_pos_emb,
|
10
|
+
)
|
11
|
+
|
12
|
+
|
13
|
+
class Qwen2_5_VisionTransformerWrapper(nn.Module):
    """Wrapper around a Qwen2.5-VL vision transformer that re-wraps its blocks.

    Blocks whose index appears in ``model.fullatt_block_indexes`` are wrapped with
    full attention; every other block is wrapped with windowed attention. After all
    blocks run, the original ``model.merger`` is applied.

    Args:
        model: The original vision transformer. Its ``fullatt_block_indexes``,
            ``merger``, ``window_size``, ``patch_size`` and ``blocks`` are reused.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self._original_mod = model
        self.fullatt_block_indexes = model.fullatt_block_indexes
        self.merger = model.merger
        # Length of one window, in tokens; used to chunk the sequence in windowed blocks.
        tokens_per_window = (model.window_size // model.patch_size) ** 2
        self.blocks = self.wrap_vision_blocks(model.blocks, tokens_per_window)

    def wrap_vision_blocks(self, blocks: torch.nn.ModuleList, window_seq_len: int):
        """Wrap each block in ``Qwen2_5_VLVisionBlock``.

        Args:
            blocks: The original vision blocks.
            window_seq_len: Window length handed to the windowed-attention blocks.

        Returns:
            An ``nn.ModuleList`` with one wrapped block per input block, in order.
        """
        full_attn_ids = self.fullatt_block_indexes
        return nn.ModuleList(
            [
                Qwen2_5_VLVisionBlock(block, idx in full_attn_ids, window_seq_len)
                for idx, block in enumerate(blocks)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        full_attn_masks: torch.Tensor,
        window_attn_masks: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        """Run every wrapped block, then the merger.

        Masks arrive as keep(1)/drop(0) values and are converted to additive form:
        positions equal to 1 become 0 and positions equal to 0 become the most
        negative float32 value.

        Args:
            hidden_states: Input features for the first block.
            full_attn_masks: Keep/drop mask for the full-attention blocks.
            window_attn_masks: Keep/drop mask for the windowed-attention blocks.
            cos: Rotary cosine table, passed to each block.
            sin: Rotary sine table, passed to each block.

        Returns:
            The output of ``self.merger``.
        """
        fill_value = torch.finfo(torch.float32).min
        full_bias = (1 - full_attn_masks) * fill_value
        window_bias = (1 - window_attn_masks) * fill_value

        for idx, block in enumerate(self.blocks):
            bias = full_bias if idx in self.fullatt_block_indexes else window_bias
            hidden_states = block(hidden_states, bias, [cos, sin])

        return self.merger(hidden_states)
|
47
|
+
|
48
|
+
|
49
|
+
class Qwen2_5_VLVisionBlock(torch.nn.Module):
    """Pre-norm residual vision block built from the sub-modules of ``model``.

    Computes ``h = h + attn(norm1(h))`` followed by ``h = h + mlp(norm2(h))``.

    Args:
        model: Original vision block; its ``norm1``, ``norm2``, ``attn`` and ``mlp`` are reused.
        is_full_attn: If true, ``model.attn`` is wrapped with full attention,
            otherwise with windowed attention.
        window_seq_len: Window length used when the block is windowed.
    """

    def __init__(self, model: torch.nn.Module, is_full_attn: bool, window_seq_len: int):
        super().__init__()
        self._origin_model = model
        self.norm1 = model.norm1
        self.norm2 = model.norm2
        self.attn = (
            Qwen2_5_VLVisionFullAttention(model.attn)
            if is_full_attn
            else Qwen2_5_VLVisionWindowAttention(model.attn, window_seq_len)
        )
        self.mlp = model.mlp

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_masks: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Apply the attention and MLP residual branches.

        Args:
            hidden_states: Block input.
            attn_masks: Additive attention mask forwarded to ``self.attn``.
            position_embeddings: ``(cos, sin)`` forwarded to ``self.attn``.

        Returns:
            The block output.
        """
        attn_out = self.attn(self.norm1(hidden_states), attn_masks, position_embeddings)
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out
|
75
|
+
|
76
|
+
|
77
|
+
class Qwen2_5_VLVisionFullAttention(nn.Module):
    """Full (non-windowed) self-attention that reuses the projections of ``model``.

    Args:
        model: Original attention module; its ``num_heads``, ``head_dim``, ``qkv``
            and ``proj`` are reused.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self._origin_model = model
        self.num_heads = model.num_heads
        self.head_dim = model.head_dim
        self.qkv = model.qkv
        self.proj = model.proj

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_masks: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Attend over the whole sequence.

        Args:
            hidden_states: Input features; dim 0 is taken as the sequence length
                (assumes a (seq_len, hidden) layout — TODO confirm against caller).
            attn_masks: Additive mask added to the scaled attention logits.
            position_embeddings: ``(cos, sin)`` passed to ``apply_rotary_pos_emb``.

        Returns:
            Output of ``self.proj`` with the leading batch dim of 1 squeezed away.
        """
        seq_length = hidden_states.shape[0]
        hidden_states = hidden_states.unsqueeze(0)
        # qkv output is split into q/k/v, each shaped (1, num_heads, seq_length, -1).
        q, k, v = (
            self.qkv(hidden_states).reshape(1, seq_length, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
        )

        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attn_masks
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
        # Softmax runs in float32; cast back so the matmul with ``v`` does not mix
        # dtypes when the model runs in half precision. This is a Python-level dtype
        # check, so no cast op is emitted when everything is already float32.
        if attn_weights.dtype != q.dtype:
            attn_weights = attn_weights.to(q.dtype)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(1, seq_length, -1)
        attn_output = self.proj(attn_output).squeeze(0)

        return attn_output
|
110
|
+
|
111
|
+
|
112
|
+
class Qwen2_5_VLVisionWindowAttention(nn.Module):
    """Windowed self-attention that reuses the projections of ``model``.

    The sequence is cut into consecutive windows of ``window_seq_len`` tokens, and
    attention is computed independently inside each window.

    Args:
        model: Original attention module; its ``num_heads``, ``head_dim``, ``qkv``
            and ``proj`` are reused.
        window_seq_len: Number of tokens in each window.
    """

    def __init__(self, model: nn.Module, window_seq_len: int) -> None:
        super().__init__()
        self._origin_model = model
        self.num_heads = model.num_heads
        self.head_dim = model.head_dim
        self.qkv = model.qkv
        self.proj = model.proj
        self.window_seq_len = window_seq_len

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_masks: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Attend within each window.

        Args:
            hidden_states: Input features; dim 0 is taken as the sequence length
                (assumes a (seq_len, hidden) layout — TODO confirm against caller).
            attn_masks: Additive mask added to the scaled attention logits.
            position_embeddings: ``(cos, sin)``; both are reshaped to
                ``(num_windows, 1, window_seq_len, -1)`` before rotary application.

        Returns:
            Output of ``self.proj`` with the leading batch dim of 1 squeezed away.

        Raises:
            ValueError: If the sequence length is not a multiple of ``window_seq_len``.
        """
        seq_length = hidden_states.shape[0]
        if seq_length % self.window_seq_len != 0:
            # Without this check, torch.stack below fails with an unclear size-mismatch error.
            raise ValueError(
                f"Sequence length ({seq_length}) must be a multiple of window_seq_len ({self.window_seq_len})."
            )
        num_windows = seq_length // self.window_seq_len

        # Split into consecutive, equally sized windows stacked along a new leading dim.
        window_hidden_states = [
            hidden_states[i : i + self.window_seq_len] for i in range(0, seq_length, self.window_seq_len)
        ]
        hidden_states = torch.stack(window_hidden_states)

        # q/k/v are each shaped (num_windows, num_heads, window_seq_len, -1).
        q, k, v = (
            self.qkv(hidden_states)
            .reshape(num_windows, self.window_seq_len, 3, self.num_heads, -1)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )
        cos, sin = position_embeddings
        cos = cos.reshape(num_windows, 1, self.window_seq_len, -1)
        sin = sin.reshape(num_windows, 1, self.window_seq_len, -1)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)

        attn_weights = attn_weights + attn_masks
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
        # Softmax runs in float32; cast back so the matmul with ``v`` does not mix
        # dtypes in half precision. No cast op is emitted when dtypes already match.
        if attn_weights.dtype != q.dtype:
            attn_weights = attn_weights.to(q.dtype)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(1, seq_length, -1)
        attn_output = self.proj(attn_output).squeeze(0)

        return attn_output
|
157
|
+
|
158
|
+
|
159
|
+
class Qwen2_5_VL_LanguageModelWrapper(DecoderOnlyWrapper):
    """Decoder-only wrapper whose flat positional inputs also carry ``position_emb``.

    The expected positional layout depends on ``self.phase`` and ``self.use_attention_mask``:

    * ``"decode"``:  inputs, cache_position, [attention_mask], block_tables, position_emb, *past_key_values
    * ``"prefill"``: inputs, cache_position, [attention_mask], query_position, block_tables,
      position_emb, *past_key_values

    ``[attention_mask]`` is present only when ``use_attention_mask`` is true. Entries
    absent from the layout are forwarded to ``forward_common`` as ``None``.
    """

    def forward(self, *args):
        if self.phase not in ("decode", "prefill"):
            raise ValueError(f"Unknown phase: {self.phase}")
        is_prefill = self.phase == "prefill"

        attention_mask = None
        query_position = None

        if self.use_attention_mask and is_prefill:
            (
                input_ids_or_inputs_embeds,
                cache_position,
                attention_mask,
                query_position,
                block_tables,
                position_emb,
                *past_key_values,
            ) = args
        elif self.use_attention_mask:
            (
                input_ids_or_inputs_embeds,
                cache_position,
                attention_mask,
                block_tables,
                position_emb,
                *past_key_values,
            ) = args
        elif is_prefill:
            (
                input_ids_or_inputs_embeds,
                cache_position,
                query_position,
                block_tables,
                position_emb,
                *past_key_values,
            ) = args
        else:
            (
                input_ids_or_inputs_embeds,
                cache_position,
                block_tables,
                position_emb,
                *past_key_values,
            ) = args

        return self.forward_common(
            input_ids_or_inputs_embeds,
            cache_position,
            attention_mask,
            query_position,
            block_tables,
            position_emb,
            *past_key_values,
        )
|
@@ -0,0 +1,66 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Optional
|
16
|
+
|
17
|
+
import rebel
|
18
|
+
|
19
|
+
from ....configuration_utils import RBLNModelConfig
|
20
|
+
from ....utils.logging import get_logger
|
21
|
+
|
22
|
+
|
23
|
+
logger = get_logger()
|
24
|
+
|
25
|
+
|
26
|
+
class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
    def __init__(
        self,
        batch_size: Optional[int] = None,
        enc_max_seq_len: Optional[int] = None,
        dec_max_seq_len: Optional[int] = None,
        use_attention_mask: Optional[bool] = None,
        pad_token_id: Optional[int] = None,
        **kwargs,
    ):
        """
        Configuration for RBLN sequence-to-sequence language models.

        Args:
            batch_size (Optional[int]): Inference batch size. Any falsy value (including None) becomes 1.
            enc_max_seq_len (Optional[int]): Maximum encoder sequence length.
            dec_max_seq_len (Optional[int]): Maximum decoder sequence length.
            use_attention_mask (Optional[bool]): Whether attention masks are used during inference.
                It is forced to True when the target NPU is RBLN-CA02 (a warning is logged if it was
                explicitly False). On other NPUs a falsy value becomes False.
            pad_token_id (Optional[int]): Id of the padding token.
            **kwargs: Forwarded to the parent RBLNModelConfig.

        Raises:
            ValueError: If the resolved batch_size is not an int or is negative.
        """
        super().__init__(**kwargs)

        self.batch_size = 1 if not batch_size else batch_size
        if not (isinstance(self.batch_size, int) and self.batch_size >= 0):
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        self.enc_max_seq_len = enc_max_seq_len
        self.dec_max_seq_len = dec_max_seq_len

        self.use_attention_mask = use_attention_mask
        # Prefer the explicitly configured NPU; only query the device when none is set.
        target_npu = self.npu or rebel.get_npu_name()
        if target_npu != "RBLN-CA02":
            self.use_attention_mask = self.use_attention_mask or False
        else:
            if self.use_attention_mask is False:
                logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
            self.use_attention_mask = True

        self.pad_token_id = pad_token_id
|
@@ -22,10 +22,11 @@ from rebel.compile_context import CompileContext
|
|
22
22
|
from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
|
23
23
|
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
24
24
|
|
25
|
+
from ....configuration_utils import RBLNCompileConfig
|
25
26
|
from ....modeling import RBLNModel
|
26
|
-
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
27
27
|
from ....utils.logging import get_logger
|
28
28
|
from ....utils.runtime_utils import RBLNPytorchRuntime
|
29
|
+
from .configuration_seq2seq2 import RBLNModelForSeq2SeqLMConfig
|
29
30
|
|
30
31
|
|
31
32
|
logger = get_logger(__name__)
|
@@ -38,8 +39,8 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
|
|
38
39
|
mandatory_members = ["main_input_name"]
|
39
40
|
|
40
41
|
def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
|
41
|
-
|
42
|
-
return BaseModelOutput(last_hidden_state=
|
42
|
+
output = super().forward(*args, **kwargs)
|
43
|
+
return BaseModelOutput(last_hidden_state=output)
|
43
44
|
|
44
45
|
|
45
46
|
class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
@@ -50,7 +51,6 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
|
50
51
|
runtime: rebel.Runtime,
|
51
52
|
batch_size: int,
|
52
53
|
dec_max_seq_len: int,
|
53
|
-
support_paged_causal_attn: Optional[bool] = None,
|
54
54
|
use_attention_mask: Optional[bool] = None,
|
55
55
|
**kwargs: Any,
|
56
56
|
) -> None:
|
@@ -58,10 +58,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
|
58
58
|
self.batch_size = batch_size
|
59
59
|
self.dec_max_seq_len = dec_max_seq_len
|
60
60
|
self.use_attention_mask = use_attention_mask
|
61
|
-
|
62
|
-
self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
|
63
|
-
else:
|
64
|
-
self.default_block_tables = None
|
61
|
+
self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
|
65
62
|
|
66
63
|
def forward(
|
67
64
|
self,
|
@@ -119,12 +116,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
119
116
|
|
120
117
|
main_input_name = "input_ids"
|
121
118
|
auto_model_class = AutoModelForSeq2SeqLM
|
122
|
-
|
119
|
+
support_causal_attn = None
|
123
120
|
|
124
121
|
def __post_init__(self, **kwargs):
|
125
|
-
batch_size = self.rbln_config.
|
126
|
-
dec_max_seq_len = self.rbln_config.
|
127
|
-
self.use_attention_mask = self.rbln_config.
|
122
|
+
batch_size = self.rbln_config.batch_size
|
123
|
+
dec_max_seq_len = self.rbln_config.dec_max_seq_len
|
124
|
+
self.use_attention_mask = self.rbln_config.use_attention_mask
|
128
125
|
|
129
126
|
self.encoder = RBLNRuntimeEncoder(
|
130
127
|
runtime=self.model[0],
|
@@ -135,13 +132,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
135
132
|
main_input_name="input_ids",
|
136
133
|
batch_size=batch_size,
|
137
134
|
dec_max_seq_len=dec_max_seq_len,
|
138
|
-
support_paged_causal_attn=self.support_paged_causal_attn,
|
139
135
|
use_attention_mask=self.use_attention_mask,
|
140
136
|
)
|
141
137
|
|
142
138
|
@classmethod
|
143
139
|
@torch.inference_mode()
|
144
|
-
def get_compiled_model(cls, model: PreTrainedModel, rbln_config:
|
140
|
+
def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
|
145
141
|
wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
|
146
142
|
|
147
143
|
enc_compile_config = rbln_config.compile_cfgs[0]
|
@@ -182,26 +178,15 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
182
178
|
return {"encoder": compiled_encoder, "decoder": compiled_decoder}
|
183
179
|
|
184
180
|
@classmethod
|
185
|
-
def
|
181
|
+
def _update_rbln_config(
|
186
182
|
cls,
|
187
183
|
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
|
195
|
-
|
196
|
-
if cls.support_paged_causal_attn:
|
197
|
-
rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
|
198
|
-
if rbln_use_attention_mask is None:
|
199
|
-
rbln_use_attention_mask = False
|
200
|
-
rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
|
201
|
-
if rbln_npu == "RBLN-CA02":
|
202
|
-
rbln_use_attention_mask = True
|
203
|
-
else:
|
204
|
-
rbln_use_attention_mask = True
|
184
|
+
model: Optional["PreTrainedModel"] = None,
|
185
|
+
model_config: Optional["PretrainedConfig"] = None,
|
186
|
+
rbln_config: Optional[RBLNModelForSeq2SeqLMConfig] = None,
|
187
|
+
) -> RBLNModelForSeq2SeqLMConfig:
|
188
|
+
if not cls.support_causal_attn:
|
189
|
+
rbln_config.use_attention_mask = True
|
205
190
|
|
206
191
|
n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
|
207
192
|
n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
|
@@ -215,79 +200,85 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
215
200
|
model_config, "max_position_embeddings", None
|
216
201
|
)
|
217
202
|
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
if max_position_embeddings is not None and
|
236
|
-
raise ValueError("`
|
237
|
-
|
238
|
-
if
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
203
|
+
pad_token_id = getattr(model_config, "pad_token_id", None)
|
204
|
+
pad_token_id = pad_token_id or getattr(model_config, "bos_token_id", None)
|
205
|
+
pad_token_id = pad_token_id or getattr(model_config, "eos_token_id", None)
|
206
|
+
pad_token_id = pad_token_id or -1
|
207
|
+
rbln_config.pad_token_id = pad_token_id
|
208
|
+
|
209
|
+
if rbln_config.enc_max_seq_len is None:
|
210
|
+
enc_max_seq_len = max_position_embeddings
|
211
|
+
for tokenizer in preprocessors:
|
212
|
+
if hasattr(tokenizer, "model_max_length"):
|
213
|
+
enc_max_seq_len = enc_max_seq_len or tokenizer.model_max_length
|
214
|
+
break
|
215
|
+
|
216
|
+
if enc_max_seq_len is None:
|
217
|
+
raise ValueError("`enc_max_seq_len` should be specified!")
|
218
|
+
rbln_config.enc_max_seq_len = enc_max_seq_len
|
219
|
+
|
220
|
+
if max_position_embeddings is not None and rbln_config.enc_max_seq_len > max_position_embeddings:
|
221
|
+
raise ValueError("`enc_max_seq_len` should be less or equal than max_position_embeddings!")
|
222
|
+
|
223
|
+
if rbln_config.dec_max_seq_len is None:
|
224
|
+
dec_max_seq_len = max_position_embeddings
|
225
|
+
for tokenizer in preprocessors:
|
226
|
+
if hasattr(tokenizer, "model_max_length"):
|
227
|
+
dec_max_seq_len = dec_max_seq_len or tokenizer.model_max_length
|
228
|
+
break
|
229
|
+
|
230
|
+
if dec_max_seq_len is None:
|
231
|
+
raise ValueError("`dec_max_seq_len` should be specified!")
|
232
|
+
rbln_config.dec_max_seq_len = dec_max_seq_len
|
233
|
+
|
234
|
+
if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
|
235
|
+
raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")
|
250
236
|
|
251
237
|
# model input info
|
252
238
|
enc_input_info = [
|
253
|
-
("input_ids", [1,
|
254
|
-
("attention_mask", [1,
|
255
|
-
(
|
256
|
-
"cross_key_value_states",
|
257
|
-
[
|
258
|
-
n_layer * 2,
|
259
|
-
rbln_batch_size,
|
260
|
-
n_head,
|
261
|
-
rbln_enc_max_seq_len,
|
262
|
-
d_kv,
|
263
|
-
],
|
264
|
-
"float32",
|
265
|
-
),
|
239
|
+
("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
|
240
|
+
("attention_mask", [1, rbln_config.enc_max_seq_len], "float32"),
|
266
241
|
("block_tables", [1], "int16"),
|
267
242
|
]
|
243
|
+
enc_input_info.extend(
|
244
|
+
[
|
245
|
+
(
|
246
|
+
f"cross_key_value_states_{i}",
|
247
|
+
[
|
248
|
+
rbln_config.batch_size,
|
249
|
+
n_head,
|
250
|
+
rbln_config.enc_max_seq_len,
|
251
|
+
d_kv,
|
252
|
+
],
|
253
|
+
"float32",
|
254
|
+
)
|
255
|
+
for i in range(n_layer * 2)
|
256
|
+
]
|
257
|
+
)
|
268
258
|
|
269
259
|
dec_input_info = [
|
270
|
-
("input_ids", [
|
271
|
-
("encoder_attention_mask", [
|
260
|
+
("input_ids", [rbln_config.batch_size, 1], "int64"),
|
261
|
+
("encoder_attention_mask", [rbln_config.batch_size, rbln_config.enc_max_seq_len], "float32"),
|
272
262
|
(
|
273
263
|
"cache_position",
|
274
|
-
[
|
264
|
+
[rbln_config.batch_size, 1],
|
275
265
|
"int32",
|
276
266
|
),
|
267
|
+
("block_tables", [rbln_config.batch_size, 1], "int16"),
|
277
268
|
]
|
278
269
|
dec_input_info.extend(
|
279
270
|
[
|
280
271
|
(
|
281
|
-
"
|
272
|
+
f"cross_key_value_states_{i}",
|
282
273
|
[
|
283
|
-
|
284
|
-
rbln_batch_size,
|
274
|
+
rbln_config.batch_size,
|
285
275
|
n_head,
|
286
|
-
|
276
|
+
rbln_config.enc_max_seq_len,
|
287
277
|
d_kv,
|
288
278
|
],
|
289
279
|
"float32",
|
290
280
|
)
|
281
|
+
for i in range(n_layer * 2)
|
291
282
|
]
|
292
283
|
)
|
293
284
|
dec_input_info.extend(
|
@@ -295,9 +286,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
295
286
|
(
|
296
287
|
f"self_key_value_states_{i}",
|
297
288
|
[
|
298
|
-
|
289
|
+
rbln_config.batch_size,
|
299
290
|
n_head,
|
300
|
-
|
291
|
+
rbln_config.dec_max_seq_len,
|
301
292
|
d_kv,
|
302
293
|
],
|
303
294
|
"float32",
|
@@ -306,48 +297,38 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
306
297
|
]
|
307
298
|
)
|
308
299
|
|
309
|
-
if
|
310
|
-
dec_input_info.insert(
|
311
|
-
|
312
|
-
|
300
|
+
if rbln_config.use_attention_mask:
|
301
|
+
dec_input_info.insert(
|
302
|
+
1, ("attention_mask", [rbln_config.batch_size, rbln_config.dec_max_seq_len], "float32")
|
303
|
+
)
|
313
304
|
|
314
305
|
enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
|
315
306
|
dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
|
316
307
|
|
317
|
-
rbln_config
|
318
|
-
rbln_cls=cls.__name__,
|
319
|
-
compile_cfgs=[enc_compile_config, dec_compile_config],
|
320
|
-
rbln_kwargs=rbln_kwargs,
|
321
|
-
)
|
322
|
-
|
323
|
-
rbln_config.model_cfg.update(
|
324
|
-
{
|
325
|
-
"enc_max_seq_len": rbln_enc_max_seq_len,
|
326
|
-
"dec_max_seq_len": rbln_dec_max_seq_len,
|
327
|
-
"batch_size": rbln_batch_size,
|
328
|
-
"pad_token_id": rbln_pad_token_id,
|
329
|
-
"use_attention_mask": rbln_use_attention_mask,
|
330
|
-
}
|
331
|
-
)
|
332
|
-
|
308
|
+
rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
|
333
309
|
return rbln_config
|
334
310
|
|
335
311
|
@classmethod
|
336
312
|
def _create_runtimes(
|
337
313
|
cls,
|
338
314
|
compiled_models: List[rebel.RBLNCompiledModel],
|
339
|
-
|
340
|
-
activate_profiler: Optional[bool] = None,
|
315
|
+
rbln_config: RBLNModelForSeq2SeqLMConfig,
|
341
316
|
) -> List[rebel.Runtime]:
|
342
|
-
if any(model_name not in
|
317
|
+
if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
|
343
318
|
cls._raise_missing_compiled_file_error(["encoder", "decoder"])
|
344
319
|
|
345
320
|
return [
|
346
|
-
|
347
|
-
|
321
|
+
rebel.Runtime(
|
322
|
+
compiled_models[0],
|
323
|
+
tensor_type="pt",
|
324
|
+
device=rbln_config.device_map["encoder"],
|
325
|
+
activate_profiler=rbln_config.activate_profiler,
|
348
326
|
),
|
349
|
-
|
350
|
-
|
327
|
+
rebel.Runtime(
|
328
|
+
compiled_models[1],
|
329
|
+
tensor_type="pt",
|
330
|
+
device=rbln_config.device_map["decoder"],
|
331
|
+
activate_profiler=rbln_config.activate_profiler,
|
351
332
|
),
|
352
333
|
]
|
353
334
|
|
@@ -369,7 +350,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
369
350
|
):
|
370
351
|
cur_seq_len = input_ids.shape[-1]
|
371
352
|
cache_position = cur_seq_len - 1
|
372
|
-
max_seq_len = self.rbln_config.
|
353
|
+
max_seq_len = self.rbln_config.dec_max_seq_len
|
373
354
|
decoder_batch_size = input_ids.shape[0]
|
374
355
|
input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
|
375
356
|
decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
|
@@ -389,7 +370,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
389
370
|
**kwargs,
|
390
371
|
) -> Tuple[torch.FloatTensor]:
|
391
372
|
# common decoder
|
392
|
-
cache_position = torch.full((self.rbln_config.
|
373
|
+
cache_position = torch.full((self.rbln_config.batch_size, 1), cache_position, dtype=torch.int32)
|
393
374
|
logits = self.decoder(decoder_input_ids=decoder_input_ids, cache_position=cache_position, **kwargs).logits
|
394
375
|
|
395
376
|
return Seq2SeqLMOutput(
|
@@ -423,11 +404,11 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
423
404
|
batch_size, input_len = inputs_tensor.shape
|
424
405
|
inputs_tensor = torch.nn.functional.pad(
|
425
406
|
inputs_tensor,
|
426
|
-
(0, self.rbln_config.
|
427
|
-
value=self.rbln_config.
|
407
|
+
(0, self.rbln_config.enc_max_seq_len - input_len),
|
408
|
+
value=self.rbln_config.pad_token_id,
|
428
409
|
)
|
429
410
|
model_kwargs["attention_mask"] = torch.nn.functional.pad(
|
430
|
-
model_kwargs["attention_mask"], (0, self.rbln_config.
|
411
|
+
model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
|
431
412
|
)
|
432
413
|
|
433
414
|
# 3. make sure that encoder returns `ModelOutput`
|