optimum-rbln 0.7.5a0-py3-none-any.whl → 0.7.5rc0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- optimum/rbln/__init__.py +30 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +9 -4
- optimum/rbln/modeling.py +7 -5
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +32 -3
- optimum/rbln/transformers/models/__init__.py +37 -0
- optimum/rbln/transformers/models/auto/__init__.py +1 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +93 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +298 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +12 -6
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +189 -90
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +186 -95
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +69 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +446 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1057 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +11 -7
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +19 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +80 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +77 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +4 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -11
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +35 -52
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -0
- optimum/rbln/transformers/models/siglip/__init__.py +20 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +66 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +146 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +121 -72
- optimum/rbln/utils/submodule.py +13 -1
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/RECORD +46 -31
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/licenses/LICENSE +0 -0
--- /dev/null
+++ b/optimum/rbln/transformers/models/blip_2/configuration_blip_2.py
@@ -0,0 +1,93 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNBlip2VisionModelConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+
+class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        num_query_tokens: Optional[int] = None,
+        image_text_hidden_size: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.num_query_tokens = num_query_tokens
+        self.image_text_hidden_size = image_text_hidden_size
+
+
+class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
+    submodules = ["vision_model", "qformer", "language_model"]
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        vision_model: Optional[RBLNModelConfig] = None,
+        qformer: Optional[RBLNModelConfig] = None,
+        language_model: Optional[RBLNModelConfig] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            vision_model (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+            language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.vision_model = self.init_submodule_config(RBLNBlip2VisionModelConfig, vision_model, batch_size=batch_size)
+        self.language_model = language_model
+        self.qformer = self.init_submodule_config(RBLNBlip2QFormerModelConfig, qformer, batch_size=batch_size)
--- /dev/null
+++ b/optimum/rbln/transformers/models/blip_2/modeling_blip_2.py
@@ -0,0 +1,298 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
+
+import torch
+from transformers import (
+    AutoModelForVisualQuestionAnswering,
+    Blip2ForConditionalGeneration,
+    Blip2QFormerModel,
+    Blip2VisionModel,
+    PretrainedConfig,
+    PreTrainedModel,
+)
+from transformers.modeling_outputs import BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions
+from transformers.utils import logging
+
+from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+from ....modeling import RBLNModel
+
+
+logger = logging.get_logger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+    )
+
+
+class RBLNBlip2VisionModel(RBLNModel):
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+        class Blip2VisionModelWrapper(torch.nn.Module):
+            def __init__(self, model: "Blip2VisionModel") -> None:
+                super().__init__()
+                self.model = model
+
+            def forward(self, *args, **kwargs):
+                kwargs.pop("return_dict", None)
+                return self.model(*args, **kwargs, return_dict=False)
+
+        return Blip2VisionModelWrapper(model).eval()
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNModelConfig] = None,
+    ) -> RBLNModelConfig:
+        input_info = [
+            (
+                "pixel_values",
+                [
+                    rbln_config.batch_size,
+                    model_config.num_channels,
+                    model_config.image_size,
+                    model_config.image_size,
+                ],
+                "float32",
+            ),
+        ]
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config.set_compile_cfgs([rbln_compile_config])
+        return rbln_config
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        output = super().forward(pixel_values, return_dict=return_dict)
+        return output
+
+    def _prepare_output(self, output, return_dict):
+        """
+        Prepare model output based on return_dict flag.
+        This method can be overridden by subclasses to provide task-specific output handling.
+        """
+        if not return_dict:
+            return (output,) if not isinstance(output, (tuple, list)) else output
+        else:
+            return BaseModelOutputWithPooling(
+                last_hidden_state=output[0],
+                pooler_output=output[1],
+            )
+
+
+class RBLNBlip2QFormerModel(RBLNModel):
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+        class Blip2QFormerModelWrapper(torch.nn.Module):
+            def __init__(self, model: "Blip2QFormerModel"):
+                super().__init__()
+                self.model = model
+
+            def forward(
+                self,
+                query_embeds: torch.FloatTensor,
+                encoder_hidden_states: Optional[torch.FloatTensor] = None,
+                encoder_attention_mask: Optional[torch.FloatTensor] = None,
+            ) -> torch.Tensor:
+                qformer_out = self.model(
+                    query_embeds=query_embeds,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )
+                return qformer_out
+
+        return Blip2QFormerModelWrapper(model).eval()
+
+    @classmethod
+    def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: "RBLNModelConfig") -> "RBLNModelConfig":
+        if rbln_config.num_query_tokens is None:
+            rbln_config.num_query_tokens = model.config.num_query_tokens
+
+        if rbln_config.image_text_hidden_size is None:
+            rbln_config.image_text_hidden_size = model.config.image_text_hidden_size
+
+        return rbln_config
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNModelConfig] = None,
+    ) -> RBLNModelConfig:
+        input_info = [
+            (
+                "query_embeds",
+                [
+                    rbln_config.batch_size,
+                    rbln_config.num_query_tokens,
+                    model_config.hidden_size,
+                ],
+                "float32",
+            ),
+            (
+                "encoder_hidden_states",
+                [
+                    rbln_config.batch_size,
+                    # image_text_hidden_size + cls token
+                    rbln_config.image_text_hidden_size + 1,
+                    model_config.encoder_hidden_size,
+                ],
+                "float32",
+            ),
+            (
+                "encoder_attention_mask",
+                # image_text_hidden_size + cls token
+                [rbln_config.batch_size, rbln_config.image_text_hidden_size + 1],
+                "int64",
+            ),
+        ]
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config.set_compile_cfgs([rbln_compile_config])
+        return rbln_config
+
+    def forward(
+        self,
+        query_embeds: torch.FloatTensor,
+        query_length: Optional[int] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        output = super().forward(query_embeds, encoder_hidden_states, encoder_attention_mask, return_dict=return_dict)
+        return output
+
+    def _prepare_output(self, output, return_dict):
+        """
+        Prepare model output based on return_dict flag.
+        This method can be overridden by subclasses to provide task-specific output handling.
+        """
+        if not return_dict:
+            return (output,) if not isinstance(output, (tuple, list)) else output
+        else:
+            return BaseModelOutputWithPoolingAndCrossAttentions(
+                last_hidden_state=output[0],
+                pooler_output=output[1],
+            )
+
+
+class RBLNBlip2ForConditionalGeneration(RBLNModel):
+    auto_model_class = AutoModelForVisualQuestionAnswering
+    _rbln_submodules = [{"name": "vision_model"}, {"name": "qformer"}, {"name": "language_model"}]
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(Blip2ForConditionalGeneration, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+        return val
+
+    def can_generate(self):
+        return True
+
+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: "Blip2ForConditionalGeneration",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNModelConfig,
+    ):
+        """
+        If you are unavoidably running on a CPU rather than an RBLN device,
+        store the torch tensor, weight, etc. in this function.
+        """
+        save_dict = {}
+        save_dict["query_tokens"] = model.query_tokens
+        torch.save(save_dict, save_dir_path / subfolder / "query_tokens.pth")
+
+    def __post_init__(self, **kwargs):
+        self.vision_model = self.rbln_submodules[0]
+        self.language_model = self.rbln_submodules[2]
+        self.qformer = self.rbln_submodules[1]
+        self.language_projection = self.model[0]
+
+        artifacts = torch.load(self.model_save_dir / self.subfolder / "query_tokens.pth", weights_only=False)
+        self.query_tokens = artifacts["query_tokens"]
+
+    def get_attn_impl(self) -> str:
+        return self.rbln_config.language_model.attn_impl
+
+    def get_kvcache_num_blocks(self) -> int:
+        return self.rbln_config.language_model.kvcache_num_blocks
+
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    @classmethod
+    def wrap_model_if_needed(cls, model, rbln_config):
+        return model.language_projection
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNModelConfig] = None,
+    ) -> RBLNModelConfig:
+        input_info = [
+            (
+                "query_output",
+                [
+                    rbln_config.batch_size,
+                    model_config.num_query_tokens,
+                    model_config.qformer_config.hidden_size,
+                ],
+                "float32",
+            ),
+        ]
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config.set_compile_cfgs([rbln_compile_config])
+
+        return rbln_config
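At runtime, `__post_init__` wires the three compiled submodules back together, and `__getattr__` transparently borrows any method still missing (notably `generate`) from Hugging Face's `Blip2ForConditionalGeneration`, rebinding `self`. An end-to-end sketch, assuming the standard optimum-rbln `from_pretrained(..., export=True)` compile flow and a stock BLIP-2 checkpoint; the `rbln_config` keys for the language model are illustrative, not confirmed by this diff:

```python
# Sketch only: requires the RBLN SDK and an RBLN NPU to actually compile/run.
import requests
from PIL import Image
from transformers import Blip2Processor

from optimum.rbln import RBLNBlip2ForConditionalGeneration

model_id = "Salesforce/blip2-opt-2.7b"
processor = Blip2Processor.from_pretrained(model_id)
model = RBLNBlip2ForConditionalGeneration.from_pretrained(
    model_id,
    export=True,  # compile vision model, Q-Former, and language model for the NPU
    rbln_config={"language_model": {"max_seq_len": 2048}},  # assumed option
)

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
inputs = processor(images=image, text="Question: what is in the photo? Answer:", return_tensors="pt")
# generate() is not defined on RBLNBlip2ForConditionalGeneration itself;
# it is resolved through the __getattr__ redirect shown above.
generated_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```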
--- a/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
+++ b/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 import rebel
 
 from ....configuration_utils import RBLNModelConfig
 from ....utils.logging import get_logger
-from ...utils.rbln_quantization import
+from ...utils.rbln_quantization import RBLNQuantizationConfig
 
 
 logger = get_logger()
@@ -31,10 +31,11 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
         max_seq_len: Optional[int] = None,
         use_inputs_embeds: Optional[bool] = None,
         use_attention_mask: Optional[bool] = None,
+        use_position_ids: Optional[bool] = None,
         attn_impl: Optional[str] = None,
         kvcache_partition_len: Optional[int] = None,
         kvcache_block_size: Optional[int] = None,
-        quantization: Optional[Dict[str, Any]] = None,
+        quantization: Optional[Union[Dict[str, Any], RBLNQuantizationConfig]] = None,
         prefill_chunk_size: Optional[int] = None,
         kvcache_num_blocks: Optional[int] = None,
         decoder_batch_sizes: Optional[List[int]] = None,
@@ -47,6 +48,7 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
             use_inputs_embeds (Optional[bool]): Whether to use input embeddings directly. Defaults to False.
             use_attention_mask (Optional[bool]): Whether to use attention masks. This is automatically set to True
                 for RBLN-CA02 devices.
+            use_position_ids (Optional[bool]): Whether to use position IDs. Defaults to False.
             attn_impl (Optional[str]): The attention implementation to use.
             kvcache_partition_len (Optional[int]): The length of each KV cache partition.
             kvcache_block_size (Optional[int]): The block size for KV cache.
@@ -74,8 +76,9 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
 
         self.max_seq_len = max_seq_len
         self.use_inputs_embeds = use_inputs_embeds or False
-
+        self.use_position_ids = use_position_ids or False
         self.use_attention_mask = use_attention_mask
+
         npu = self.npu or rebel.get_npu_name()
         if npu == "RBLN-CA02":
             if self.use_attention_mask is False:
@@ -84,12 +87,15 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
         else:
             self.use_attention_mask = self.use_attention_mask or False
 
+        if self.use_position_ids and not self.use_attention_mask:
+            raise ValueError("Position IDs should be used with attention mask.")
+
         self.attn_impl = attn_impl
         self.kvcache_partition_len = kvcache_partition_len
        self.kvcache_block_size = kvcache_block_size
         self.quantization = quantization or {}
-        if self.quantization:
-
+        if self.quantization and isinstance(self.quantization, dict):
+            self.quantization = RBLNQuantizationConfig(**self.quantization)
 
         self.prefill_chunk_size = prefill_chunk_size or 128
         if self.prefill_chunk_size % 64 != 0 or self.prefill_chunk_size <= 0:
|