optimum-rbln 0.7.5a0__py3-none-any.whl → 0.7.5a1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registry.
- optimum/rbln/__init__.py +20 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +9 -4
- optimum/rbln/modeling.py +7 -5
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +22 -3
- optimum/rbln/transformers/models/__init__.py +23 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +93 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +298 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +12 -6
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +81 -77
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +160 -88
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +11 -7
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +19 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +78 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +16 -10
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +35 -52
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -0
- optimum/rbln/transformers/models/siglip/__init__.py +20 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +66 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +146 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +121 -72
- optimum/rbln/utils/submodule.py +13 -1
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/RECORD +35 -24
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5a1.dist-info}/licenses/LICENSE +0 -0
--- /dev/null
+++ b/optimum/rbln/transformers/models/blip_2/modeling_blip_2.py
@@ -0,0 +1,298 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
+
+import torch
+from transformers import (
+    AutoModelForVisualQuestionAnswering,
+    Blip2ForConditionalGeneration,
+    Blip2QFormerModel,
+    Blip2VisionModel,
+    PretrainedConfig,
+    PreTrainedModel,
+)
+from transformers.modeling_outputs import BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions
+from transformers.utils import logging
+
+from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+from ....modeling import RBLNModel
+
+
+logger = logging.get_logger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+    )
+
+
+class RBLNBlip2VisionModel(RBLNModel):
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+        class Blip2VisionModelWrapper(torch.nn.Module):
+            def __init__(self, model: "Blip2VisionModel") -> None:
+                super().__init__()
+                self.model = model
+
+            def forward(self, *args, **kwargs):
+                kwargs.pop("return_dict", None)
+                return self.model(*args, **kwargs, return_dict=False)
+
+        return Blip2VisionModelWrapper(model).eval()
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNModelConfig] = None,
+    ) -> RBLNModelConfig:
+        input_info = [
+            (
+                "pixel_values",
+                [
+                    rbln_config.batch_size,
+                    model_config.num_channels,
+                    model_config.image_size,
+                    model_config.image_size,
+                ],
+                "float32",
+            ),
+        ]
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config.set_compile_cfgs([rbln_compile_config])
+        return rbln_config
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        output = super().forward(pixel_values, return_dict=return_dict)
+        return output
+
+    def _prepare_output(self, output, return_dict):
+        """
+        Prepare model output based on return_dict flag.
+        This method can be overridden by subclasses to provide task-specific output handling.
+        """
+        if not return_dict:
+            return (output,) if not isinstance(output, (tuple, list)) else output
+        else:
+            return BaseModelOutputWithPooling(
+                last_hidden_state=output[0],
+                pooler_output=output[1],
+            )
+
+
+class RBLNBlip2QFormerModel(RBLNModel):
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+        class Blip2QFormerModelWrapper(torch.nn.Module):
+            def __init__(self, model: "Blip2QFormerModel"):
+                super().__init__()
+                self.model = model
+
+            def forward(
+                self,
+                query_embeds: torch.FloatTensor,
+                encoder_hidden_states: Optional[torch.FloatTensor] = None,
+                encoder_attention_mask: Optional[torch.FloatTensor] = None,
+            ) -> torch.Tensor:
+                qformer_out = self.model(
+                    query_embeds=query_embeds,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )
+                return qformer_out
+
+        return Blip2QFormerModelWrapper(model).eval()
+
+    @classmethod
+    def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: "RBLNModelConfig") -> "RBLNModelConfig":
+        if rbln_config.num_query_tokens is None:
+            rbln_config.num_query_tokens = model.config.num_query_tokens
+
+        if rbln_config.image_text_hidden_size is None:
+            rbln_config.image_text_hidden_size = model.config.image_text_hidden_size
+
+        return rbln_config
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNModelConfig] = None,
+    ) -> RBLNModelConfig:
+        input_info = [
+            (
+                "query_embeds",
+                [
+                    rbln_config.batch_size,
+                    rbln_config.num_query_tokens,
+                    model_config.hidden_size,
+                ],
+                "float32",
+            ),
+            (
+                "encoder_hidden_states",
+                [
+                    rbln_config.batch_size,
+                    # image_text_hidden_size + cls token
+                    rbln_config.image_text_hidden_size + 1,
+                    model_config.encoder_hidden_size,
+                ],
+                "float32",
+            ),
+            (
+                "encoder_attention_mask",
+                # image_text_hidden_size + cls token
+                [rbln_config.batch_size, rbln_config.image_text_hidden_size + 1],
+                "int64",
+            ),
+        ]
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config.set_compile_cfgs([rbln_compile_config])
+        return rbln_config
+
+    def forward(
+        self,
+        query_embeds: torch.FloatTensor,
+        query_length: Optional[int] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        output = super().forward(query_embeds, encoder_hidden_states, encoder_attention_mask, return_dict=return_dict)
+        return output
+
+    def _prepare_output(self, output, return_dict):
+        """
+        Prepare model output based on return_dict flag.
+        This method can be overridden by subclasses to provide task-specific output handling.
+        """
+        if not return_dict:
+            return (output,) if not isinstance(output, (tuple, list)) else output
+        else:
+            return BaseModelOutputWithPoolingAndCrossAttentions(
+                last_hidden_state=output[0],
+                pooler_output=output[1],
+            )
+
+
+class RBLNBlip2ForConditionalGeneration(RBLNModel):
+    auto_model_class = AutoModelForVisualQuestionAnswering
+    _rbln_submodules = [{"name": "vision_model"}, {"name": "qformer"}, {"name": "language_model"}]
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(Blip2ForConditionalGeneration, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+        return val
+
+    def can_generate(self):
+        return True
+
+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: "Blip2ForConditionalGeneration",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNModelConfig,
+    ):
+        """
+        If you are unavoidably running on a CPU rather than an RBLN device,
+        store the torch tensor, weight, etc. in this function.
+        """
+        save_dict = {}
+        save_dict["query_tokens"] = model.query_tokens
+        torch.save(save_dict, save_dir_path / subfolder / "query_tokens.pth")
+
+    def __post_init__(self, **kwargs):
+        self.vision_model = self.rbln_submodules[0]
+        self.language_model = self.rbln_submodules[2]
+        self.qformer = self.rbln_submodules[1]
+        self.language_projection = self.model[0]
+
+        artifacts = torch.load(self.model_save_dir / self.subfolder / "query_tokens.pth", weights_only=False)
+        self.query_tokens = artifacts["query_tokens"]
+
+    def get_attn_impl(self) -> str:
+        return self.rbln_config.language_model.attn_impl
+
+    def get_kvcache_num_blocks(self) -> int:
+        return self.rbln_config.language_model.kvcache_num_blocks
+
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    @classmethod
+    def wrap_model_if_needed(cls, model, rbln_config):
+        return model.language_projection
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNModelConfig] = None,
+    ) -> RBLNModelConfig:
+        input_info = [
+            (
+                "query_output",
+                [
+                    rbln_config.batch_size,
+                    model_config.num_query_tokens,
+                    model_config.qformer_config.hidden_size,
+                ],
+                "float32",
+            ),
+        ]
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config.set_compile_cfgs([rbln_compile_config])
+
+        return rbln_config
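
The hunk above adds three wrappers, RBLNBlip2VisionModel, RBLNBlip2QFormerModel, and RBLNBlip2ForConditionalGeneration, with the vision tower, Q-Former, and language model compiled as separate RBLN submodules. A minimal usage sketch follows, assuming the class is re-exported from `optimum.rbln` (consistent with the `__init__.py` changes in this release) and follows the library's usual `from_pretrained(..., export=True)` compile-on-load flow; the checkpoint id and image URL are illustrative assumptions, not taken from this diff.

```python
# Hypothetical sketch: compile and run BLIP-2 with the classes added above.
import requests
from PIL import Image
from transformers import Blip2Processor

from optimum.rbln import RBLNBlip2ForConditionalGeneration  # assumed to be exported in 0.7.5a1

model_id = "Salesforce/blip2-opt-2.7b"  # illustrative checkpoint
processor = Blip2Processor.from_pretrained(model_id)
model = RBLNBlip2ForConditionalGeneration.from_pretrained(
    model_id,
    export=True,  # trace and compile the vision model, Q-Former, and language model submodules
)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="Question: how many cats are there? Answer:", return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())
```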
--- a/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
+++ b/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 import rebel
 
 from ....configuration_utils import RBLNModelConfig
 from ....utils.logging import get_logger
-from ...utils.rbln_quantization import
+from ...utils.rbln_quantization import RBLNQuantizationConfig
 
 
 logger = get_logger()
@@ -31,10 +31,11 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
         max_seq_len: Optional[int] = None,
         use_inputs_embeds: Optional[bool] = None,
         use_attention_mask: Optional[bool] = None,
+        use_position_ids: Optional[bool] = None,
         attn_impl: Optional[str] = None,
         kvcache_partition_len: Optional[int] = None,
         kvcache_block_size: Optional[int] = None,
-        quantization: Optional[Dict[str, Any]] = None,
+        quantization: Optional[Union[Dict[str, Any], RBLNQuantizationConfig]] = None,
         prefill_chunk_size: Optional[int] = None,
         kvcache_num_blocks: Optional[int] = None,
         decoder_batch_sizes: Optional[List[int]] = None,
@@ -47,6 +48,7 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
             use_inputs_embeds (Optional[bool]): Whether to use input embeddings directly. Defaults to False.
             use_attention_mask (Optional[bool]): Whether to use attention masks. This is automatically set to True
                 for RBLN-CA02 devices.
+            use_position_ids (Optional[bool]): Whether to use position IDs. Defaults to False.
             attn_impl (Optional[str]): The attention implementation to use.
             kvcache_partition_len (Optional[int]): The length of each KV cache partition.
             kvcache_block_size (Optional[int]): The block size for KV cache.
@@ -74,8 +76,9 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
 
         self.max_seq_len = max_seq_len
         self.use_inputs_embeds = use_inputs_embeds or False
-
+        self.use_position_ids = use_position_ids or False
         self.use_attention_mask = use_attention_mask
+
         npu = self.npu or rebel.get_npu_name()
         if npu == "RBLN-CA02":
             if self.use_attention_mask is False:
@@ -84,12 +87,15 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
         else:
             self.use_attention_mask = self.use_attention_mask or False
 
+        if self.use_position_ids and not self.use_attention_mask:
+            raise ValueError("Position IDs should be used with attention mask.")
+
         self.attn_impl = attn_impl
         self.kvcache_partition_len = kvcache_partition_len
         self.kvcache_block_size = kvcache_block_size
         self.quantization = quantization or {}
-        if self.quantization:
-
+        if self.quantization and isinstance(self.quantization, dict):
+            self.quantization = RBLNQuantizationConfig(**self.quantization)
 
         self.prefill_chunk_size = prefill_chunk_size or 128
         if self.prefill_chunk_size % 64 != 0 or self.prefill_chunk_size <= 0:
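
Two behavioural changes stand out in the configuration hunks above: a plain `quantization` dict is now coerced into an `RBLNQuantizationConfig`, and the new `use_position_ids` flag is rejected unless `use_attention_mask` is also enabled. A small sketch of the validation path, assuming the config class can be constructed directly from the module path listed in this diff (per-model config subclasses are the normal entry point, and the exact keys accepted by `RBLNQuantizationConfig` are not shown here):

```python
# Sketch only: constructing the base decoder-only config directly is for illustration.
from optimum.rbln.transformers.models.decoderonly.configuration_decoderonly import (
    RBLNDecoderOnlyModelForCausalLMConfig,
)

try:
    RBLNDecoderOnlyModelForCausalLMConfig(use_position_ids=True, use_attention_mask=False)
except ValueError as err:
    print(err)  # "Position IDs should be used with attention mask."

# Passing quantization as a dict is still accepted; it is converted to RBLNQuantizationConfig
# internally (the dict keys must match that class's __init__, which this hunk does not show).
```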
--- a/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
+++ b/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -146,7 +146,10 @@ class DecoderOnlyWrapper(nn.Module):
         max_seq_len: int,
         use_rotary_emb: bool,
         attn_impl: str,
+        use_inputs_embeds: bool,
         use_attention_mask: bool,
+        use_position_ids: bool,
+        use_learned_pos_emb: Optional[bool] = None,
         kvcache_partition_len: Optional[int] = None,
         kvcache_block_size: Optional[int] = None,
     ):
@@ -161,6 +164,10 @@ class DecoderOnlyWrapper(nn.Module):
         self.attn_impl = attn_impl
         self.kvcache_block_size = kvcache_block_size
         self.use_attention_mask = use_attention_mask
+        self.use_position_ids = use_position_ids
+        self.use_inputs_embeds = use_inputs_embeds
+        self.use_learned_pos_emb = use_learned_pos_emb
+
         if self.attn_impl == "flash_attn":
             self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
         elif self.attn_impl == "eager":
@@ -209,6 +216,7 @@ class DecoderOnlyWrapper(nn.Module):
             partition_len=self.kvcache_partition_len,
             max_seq_len=max_seq_len,
             kvcache_block_size=self.kvcache_block_size,
+            use_learned_pos_emb=self.use_learned_pos_emb,
         )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
@@ -222,24 +230,16 @@ class DecoderOnlyWrapper(nn.Module):
         self._phase = phase
         self.causal_lm.phase = phase
 
-    def
-
-
-
-
-
-
-
-
-
-        if input_ids_or_inputs_embeds.ndim == 2:
-            input_ids = input_ids_or_inputs_embeds
-            inputs_embeds = None
-        elif input_ids_or_inputs_embeds.ndim == 3:
-            input_ids = None
-            inputs_embeds = input_ids_or_inputs_embeds
-        else:
-            raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+    def prepare_forward_args(self, *args):
+        args = list(args)
+        input_ids = None if self.use_inputs_embeds else args.pop(0)
+        inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
+        cache_position = args.pop(0)
+        block_tables = args.pop(0)
+        query_position = args.pop(0) if self.phase == "prefill" else None
+        attention_mask = args.pop(0) if self.use_attention_mask else None
+        position_ids = args.pop(0) if self.use_position_ids else None
+        past_key_values = args
 
         if len(past_key_values) != 2 * self.num_hidden_layers:
             raise ValueError(
@@ -256,11 +256,37 @@ class DecoderOnlyWrapper(nn.Module):
             _past_key_values.append(past_key_value)
         past_key_values = _past_key_values
 
+        return (
+            input_ids,
+            inputs_embeds,
+            cache_position,
+            block_tables,
+            query_position,
+            attention_mask,
+            position_ids,
+            past_key_values,
+            self.rotary_emb,
+        )
+
+    def forward(self, *args):
+        (
+            input_ids,
+            inputs_embeds,
+            cache_position,
+            block_tables,
+            query_position,
+            attention_mask,
+            position_ids,
+            past_key_values,
+            rotary_emb,
+        ) = self.prepare_forward_args(*args)
+
         logit = self.causal_lm(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            position_ids=position_ids,
             query_position=query_position,
             past_key_values=past_key_values,
             rotary_emb=rotary_emb,
@@ -269,58 +295,6 @@ class DecoderOnlyWrapper(nn.Module):
 
         return logit
 
-    def forward(self, *args):
-        if self.phase == "decode":
-            if self.use_attention_mask:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    attention_mask,
-                    block_tables,
-                    *past_key_values,
-                ) = args
-            else:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    block_tables,
-                    *past_key_values,
-                ) = args
-                attention_mask = None
-            query_position = None
-        elif self.phase == "prefill":
-            if self.use_attention_mask:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    attention_mask,
-                    query_position,
-                    block_tables,
-                    *past_key_values,
-                ) = args
-            else:
-                (
-                    input_ids_or_inputs_embeds,
-                    cache_position,
-                    query_position,
-                    block_tables,
-                    *past_key_values,
-                ) = args
-                attention_mask = None
-
-        else:
-            raise ValueError(f"Unknown phase: {self.phase}")
-
-        return self.forward_common(
-            input_ids_or_inputs_embeds,
-            cache_position,
-            attention_mask,
-            query_position,
-            block_tables,
-            self.rotary_emb,
-            *past_key_values,
-        )
-
 
 class DecoderOnlyForCausalLM(nn.Module):
     """A specialized wrapper for Causal Language Models optimized for RBLN compilation.
@@ -367,6 +341,7 @@ class DecoderOnlyForCausalLM(nn.Module):
         inputs_embeds: torch.Tensor = None,
         attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
+        position_ids: torch.Tensor = None,
         query_position: torch.Tensor = None,
         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
         rotary_emb: nn.Module = None,
@@ -378,6 +353,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            position_ids=position_ids,
             past_key_values=past_key_values,
             rotary_emb=rotary_emb,
             block_tables=block_tables,
@@ -404,7 +380,13 @@ class DecoderOnlyModel(nn.Module):
     """
 
     def __init__(
-        self,
+        self,
+        model,
+        layers: List["DecoderOnlyLayer"],
+        partition_len=None,
+        max_seq_len=None,
+        kvcache_block_size=None,
+        use_learned_pos_emb=None,
     ):
         super().__init__()
         self._original_mod = model
@@ -413,6 +395,7 @@ class DecoderOnlyModel(nn.Module):
         self.partition_len = partition_len
         self.kvcache_block_size = kvcache_block_size
         self.max_seq_len = max_seq_len
+        self.use_learned_pos_emb = use_learned_pos_emb
 
     @property
     def phase(self):
@@ -457,11 +440,12 @@ class DecoderOnlyModel(nn.Module):
     def forward(
         self,
         input_ids: torch.Tensor = None,
-        inputs_embeds: torch.Tensor = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         attention_mask: torch.Tensor = None,
         cache_position: torch.Tensor = None,
+        position_ids: torch.Tensor = None,
         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
-        rotary_emb: nn.Module = None,
+        rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
         block_tables: Optional[torch.Tensor] = None,
     ):
         # retrieve input_ids and inputs_embeds
@@ -477,24 +461,38 @@ class DecoderOnlyModel(nn.Module):
         hidden_states = inputs_embeds * self.hidden_multiplier
 
         # get cos,sin vector if needed
+        position_ids = position_ids if position_ids is not None else cache_position
         if rotary_emb is not None:
             if isinstance(rotary_emb, torch.Tensor):
                 cos = rotary_emb[0]
                 sin = rotary_emb[1]
             else:
                 cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
-            cos, sin = slice_and_unsqueeze_cos_sin(cos, sin,
+            cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+
+        elif self.use_learned_pos_emb:
+            batch_size = inputs_embeds.shape[0]
+            hidden_all = []
+            for i in range(batch_size):
+                positions_idx = position_ids[i]
+                position_weight = self.get_pos_embedding().weight[2:]
+                position = position_weight[positions_idx]
+                batch_hidden = position + inputs_embeds[i]
+                hidden_all.append(batch_hidden)
+            hidden_states = torch.stack(hidden_all, dim=0)
+            cos, sin = None, None
+
         else:
             batch_size = inputs_embeds.shape[0]
-            if
+            if position_ids.shape[0] > 1:
                 position_embeds = []
                 for b_idx in range(batch_size):
-                    position_embed = self.get_pos_embedding()(
+                    position_embed = self.get_pos_embedding()(position_ids[b_idx])
                     position_embeds.append(position_embed)
 
                 position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
             else:
-                position_embeds = self.get_pos_embedding()(
+                position_embeds = self.get_pos_embedding()(position_ids)
             hidden_states = hidden_states + position_embeds
             cos, sin = None, None
 
@@ -798,6 +796,7 @@ class AttentionOp(nn.Module):
                 scale=scale,
                 block_table=block_tables,
                 block_size=block_size,
+                mask=None,
             )
 
         else:
@@ -825,6 +824,8 @@ class AttentionOp(nn.Module):
                 scale=scale,
                 block_table=block_tables,
                 block_size=block_size,
+                is_bidirectional=False,
+                mask=None,
             )
 
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
@@ -1058,6 +1059,7 @@ class FlashAttentionOp(AttentionOp):
                 block_table=block_tables,
                 block_size=kvcache_block_size,
                 partition=self.kvcache_partition_size,
+                mask=None,
             )
         else:
            if self.use_attention_mask:
@@ -1086,6 +1088,8 @@ class FlashAttentionOp(AttentionOp):
                 block_table=block_tables,
                 block_size=kvcache_block_size,
                 partition=self.kvcache_partition_size,
+                is_bidirectional=False,
+                mask=None,
             )
 
         # reshape for removing repeat_kv