optimum_rbln-0.7.4a4-py3-none-any.whl → optimum_rbln-0.7.4a6-py3-none-any.whl
This diff reflects the contents of two publicly released versions of the package, as published to their public registries, and is provided for informational purposes only.
- optimum/rbln/__init__.py +164 -36
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +772 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +63 -122
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +55 -70
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- optimum/rbln/modeling.py +58 -39
- optimum/rbln/modeling_base.py +107 -78
- optimum/rbln/transformers/__init__.py +87 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +108 -34
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +115 -84
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +282 -216
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/runtime_utils.py +33 -2
- optimum/rbln/utils/submodule.py +26 -43
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/METADATA +1 -1
- optimum_rbln-0.7.4a6.dist-info/RECORD +166 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/clip/modeling_clip.py:

```diff
@@ -12,28 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import torch
-from transformers import (
-    CLIPTextConfig,
-    CLIPTextModel,
-    CLIPVisionConfig,
-    CLIPVisionModel,
-)
-from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPVisionConfig, CLIPVisionModel
 from transformers.models.clip.modeling_clip import CLIPTextModelOutput, CLIPVisionModelOutput
 
-from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
+from .configuration_clip import RBLNCLIPTextModelConfig, RBLNCLIPVisionModelConfig
 
 
 logger = get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, CLIPTextModel
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, CLIPTextModel, PreTrainedModel
+
+    from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
 
 
 class _TextEncoder(torch.nn.Module):
```
```diff
@@ -48,53 +44,55 @@ class _TextEncoder(torch.nn.Module):
 
 class RBLNCLIPTextModel(RBLNModel):
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPTextModelConfig) -> torch.nn.Module:
         return _TextEncoder(model).eval()
 
     @classmethod
-    def update_rbln_config_using_pipe(cls, pipe: "RBLNDiffusionMixin", rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+    def update_rbln_config_using_pipe(
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_config: str
+    ) -> "RBLNDiffusionMixinConfig":
         return rbln_config
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "CLIPTextConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        model_config.return_dict = False
-
+        model: Optional["PreTrainedModel"] = None,
+        model_config: "CLIPTextConfig" = None,
+        rbln_config: Optional[RBLNCLIPTextModelConfig] = None,
+    ) -> RBLNCLIPTextModelConfig:
         input_info = [
             (
                 "input_ids",
                 [
-                    rbln_batch_size,
+                    rbln_config.batch_size,
                     model_config.max_position_embeddings,
                 ],
                 "int64",
             ),
         ]
 
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
+        rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
         return rbln_config
 
-    def forward(self, input_ids: "torch.Tensor", **kwargs):
-        text_output = super().forward(input_ids)
-        return CLIPTextModelOutput(
-            text_embeds=text_output[0],
-            last_hidden_state=text_output[1],
-            hidden_states=text_output[2:],
-        )
+    def forward(self, input_ids: torch.LongTensor, return_dict: bool = None, **kwargs) -> torch.FloatTensor:
+        # To ignore using attention_mask, we override forward method.
+        output = super().forward(input_ids, return_dict=return_dict)
+        return output
+
+    def _prepare_output(self, output, return_dict):
+        """
+        Prepare model output based on return_dict flag.
+        This method can be overridden by subclasses to provide task-specific output handling.
+        """
+        if not return_dict:
+            return (output,) if not isinstance(output, (tuple, list)) else output
+        else:
+            return CLIPTextModelOutput(
+                text_embeds=output[0],
+                last_hidden_state=output[1],
+                hidden_states=output[2:],
+            )
 
 
 class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
```
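The rewritten text encoder splits output handling out of `forward` into the `_prepare_output` hook shown above, which converts the compiled runtime's flat tuple into a `CLIPTextModelOutput` only when `return_dict` is requested. A standalone sketch of that contract (the tensor shapes are illustrative stand-ins for a real runtime output):

```python
import torch
from transformers.models.clip.modeling_clip import CLIPTextModelOutput


def prepare_output(output, return_dict):
    # Mirrors RBLNCLIPTextModel._prepare_output from the hunk above.
    if not return_dict:
        return (output,) if not isinstance(output, (tuple, list)) else output
    else:
        return CLIPTextModelOutput(
            text_embeds=output[0],
            last_hidden_state=output[1],
            hidden_states=output[2:],
        )


# The compiled model returns a flat tuple: (text_embeds, last_hidden_state, *hidden_states).
raw = (torch.zeros(1, 512), torch.zeros(1, 77, 512), torch.zeros(1, 77, 512))
assert isinstance(prepare_output(raw, return_dict=True), CLIPTextModelOutput)
assert isinstance(prepare_output(raw, return_dict=False), tuple)
```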
```diff
@@ -113,30 +111,30 @@ class _VisionEncoder(torch.nn.Module):
 
 class RBLNCLIPVisionModel(RBLNModel):
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
         return _VisionEncoder(model).eval()
 
     @classmethod
-    def update_rbln_config_using_pipe(cls, pipe: "RBLNDiffusionMixin", rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+    def update_rbln_config_using_pipe(
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
         return rbln_config
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "CLIPVisionConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_batch_size = rbln_kwargs.get("batch_size", 1)
-        rbln_image_size = rbln_kwargs.get("image_size", None)
-        if rbln_image_size is None:
-            rbln_image_size = getattr(model_config, "image_size", None)
+        model: Optional["PreTrainedModel"] = None,
+        model_config: "CLIPVisionConfig" = None,
+        rbln_config: Optional[RBLNCLIPVisionModelConfig] = None,
+    ) -> RBLNCLIPVisionModelConfig:
+        if rbln_config.image_size is None:
+            rbln_config.image_size = getattr(model_config, "image_size", None)
 
-        if isinstance(rbln_image_size, int):
-            rbln_image_size = (rbln_image_size, rbln_image_size)
+        if isinstance(rbln_config.image_size, int):
+            rbln_config.image_size = (rbln_config.image_size, rbln_config.image_size)
 
-        if rbln_image_size is None:
+        if rbln_config.image_size is None:
             raise ValueError("`rbln_image_size` should be specified!")
 
         rbln_compile_config = RBLNCompileConfig(
```
```diff
@@ -144,45 +142,44 @@ class RBLNCLIPVisionModel(RBLNModel):
                 (
                     "pixel_values",
                     [
-                        rbln_batch_size,
+                        rbln_config.batch_size,
                         3,
-                        rbln_image_size[0],
-                        rbln_image_size[1],
+                        rbln_config.image_height,
+                        rbln_config.image_width,
                     ],
                     "float32",
                 )
             ]
         )
 
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "batch_size": rbln_batch_size,
-                "image_size": rbln_image_size,
-            }
-        )
-
+        rbln_config.set_compile_cfgs([rbln_compile_config])
         return rbln_config
 
     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
+        return_dict: bool = None,
         **kwargs,
-    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+    ) -> Union[Tuple, CLIPVisionModelOutput]:
         if len(kwargs) > 0 and any(kwargs.values()):
             logger.warning(f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__}.")
 
-        output = super().forward(pixel_values)
-        return BaseModelOutputWithPooling(
-            last_hidden_state=output[0],
-            pooler_output=output[1],
-            hidden_states=output[2:],
-        )
+        output = super().forward(pixel_values, return_dict=return_dict)
+        return output
+
+    def _prepare_output(self, output, return_dict):
+        """
+        Prepare model output based on return_dict flag.
+        This method can be overridden by subclasses to provide task-specific output handling.
+        """
+        if not return_dict:
+            return (output,) if not isinstance(output, (tuple, list)) else output
+        else:
+            return CLIPVisionModelOutput(
+                image_embeds=output[0],
+                last_hidden_state=output[1],
+                hidden_states=output[2:],
+            )
 
 
 class RBLNCLIPVisionModelWithProjection(RBLNCLIPVisionModel):
```
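Together with the new `RBLNCLIPVisionModelConfig`, compilation becomes config-object driven instead of `rbln_kwargs`-driven. A hedged sketch of what that flow might look like (checkpoint name and sizes are illustrative; this assumes the config and model classes are re-exported at the package root, as the expanded `__init__.py` in this release suggests, and it needs the RBLN SDK and an NPU to actually compile):

```python
from optimum.rbln import RBLNCLIPVisionModel, RBLNCLIPVisionModelConfig

# An int image_size is widened to (height, width) by _update_rbln_config above.
conf = RBLNCLIPVisionModelConfig(batch_size=2, image_size=224)

model = RBLNCLIPVisionModel.from_pretrained(
    "openai/clip-vit-base-patch32",  # example checkpoint
    export=True,                     # compile from the PyTorch weights
    rbln_config=conf,
)
```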
optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py (new file):

```diff
@@ -0,0 +1,90 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+import rebel
+
+from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+from ...utils.rbln_quantization import QuantizationManager
+
+
+logger = get_logger()
+
+
+class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        max_seq_len: Optional[int] = None,
+        use_inputs_embeds: Optional[bool] = None,
+        use_attention_mask: Optional[bool] = None,
+        attn_impl: Optional[str] = None,
+        kvcache_partition_len: Optional[int] = None,
+        kvcache_block_size: Optional[int] = None,
+        quantization: Optional[Dict[str, Any]] = None,
+        prefill_chunk_size: Optional[int] = None,
+        kvcache_num_blocks: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            max_seq_len (Optional[int]): The maximum sequence length supported by the model.
+            use_inputs_embeds (Optional[bool]): Whether to use input embeddings directly. Defaults to False.
+            use_attention_mask (Optional[bool]): Whether to use attention masks. This is automatically set to True
+                for RBLN-CA02 devices.
+            attn_impl (Optional[str]): The attention implementation to use.
+            kvcache_partition_len (Optional[int]): The length of each KV cache partition.
+            kvcache_block_size (Optional[int]): The block size for KV cache.
+            quantization (Optional[Dict[str, Any]]): Configuration for model quantization.
+            prefill_chunk_size (Optional[int]): The chunk size for prefilling the KV cache. Defaults to 128,
+                and must be a positive integer divisible by 64.
+            kvcache_num_blocks (Optional[int]): The number of blocks in the KV cache.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer or if prefill_chunk_size is not
+                a positive integer divisible by 64.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.max_seq_len = max_seq_len
+        self.use_inputs_embeds = use_inputs_embeds or False
+
+        self.use_attention_mask = use_attention_mask
+        npu = self.npu or rebel.get_npu_name()
+        if npu == "RBLN-CA02":
+            if self.use_attention_mask is False:
+                logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
+            self.use_attention_mask = True
+        else:
+            self.use_attention_mask = self.use_attention_mask or False
+
+        self.attn_impl = attn_impl
+        self.kvcache_partition_len = kvcache_partition_len
+        self.kvcache_block_size = kvcache_block_size
+        self.quantization = quantization or {}
+        if self.quantization:
+            QuantizationManager.validate_quantization_config(self.quantization)
+
+        self.prefill_chunk_size = prefill_chunk_size or 128
+        if self.prefill_chunk_size % 64 != 0 or self.prefill_chunk_size <= 0:
+            raise ValueError("`prefill_chunk_size` must be a positive integer divisible by 64.")
+
+        self.kvcache_num_blocks = kvcache_num_blocks
```
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py:

```diff
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import nn
```
```diff
@@ -32,30 +32,39 @@ MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
 MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
 
 
-def validate_attention_method(
-    rbln_attn_impl: Optional[str] = None,
-    rbln_kvcache_partition_len: Optional[int] = None,
-    rbln_kvcache_block_size: Optional[int] = None,
-    rbln_max_seq_len: Optional[int] = None,
-) -> Tuple[str, int, int]:
-    if rbln_attn_impl is None:
-        rbln_attn_impl = "eager"
-
-    if rbln_kvcache_partition_len is not None:
-        if rbln_attn_impl == "eager":
-            rbln_attn_impl = "flash_attn"
+def set_default_values(
+    attn_impl: Optional[str] = None,
+    kvcache_partition_len: Optional[int] = None,
+    kvcache_block_size: Optional[int] = None,
+    max_seq_len: Optional[int] = None,
+) -> Tuple[str, int, int]:
+    if attn_impl is None:
+        attn_impl = "eager"
+
+    if kvcache_partition_len is not None:
+        if attn_impl == "eager":
+            attn_impl = "flash_attn"
             logger.warning(
-                "A non-null `rbln_kvcache_partition_len` was provided, but `rbln_attn_impl` was not explicitly set. "
-                "Since KV cache partitioning is only supported with flash attention, "
-                "`rbln_attn_impl` has been automatically switched to 'flash_attn'."
+                "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
+                "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
+                "`attn_impl` has been automatically switched to 'flash_attn'."
             )
 
-    if rbln_kvcache_partition_len is None and rbln_attn_impl == "flash_attn":
-        rbln_kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-
+    if kvcache_partition_len is None and attn_impl == "flash_attn":
+        kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
+
+    if kvcache_block_size is None:
+        if attn_impl == "eager":
+            kvcache_block_size = max_seq_len
+        else:
+            kvcache_block_size = kvcache_partition_len
+
+    return attn_impl, kvcache_partition_len, kvcache_block_size
 
-    if rbln_attn_impl not in ["eager", "flash_attn"]:
-        raise ValueError(f"Unknown `rbln_attn_impl` : {rbln_attn_impl}. (Available : 'eager', 'flash_attn`)")
+
+def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
+    if attn_impl not in ["eager", "flash_attn"]:
+        raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn`)")
 
     ## Checking Constraints...
     # Constraint of eager attention:
```
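`set_default_values` is a pure function, so the resolution rules above can be exercised directly. A sketch, assuming an optimum-rbln environment (the import path follows the file list at the top of this diff):

```python
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import set_default_values

# With nothing specified, attention defaults to "eager" and the KV cache
# block size collapses to the full sequence length.
assert set_default_values(max_seq_len=4096) == ("eager", None, 4096)

# A partition length upgrades "eager" to "flash_attn" (with a warning),
# and the block size then tracks the partition length.
assert set_default_values(kvcache_partition_len=8192, max_seq_len=32768) == ("flash_attn", 8192, 8192)
```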
```diff
@@ -65,47 +74,45 @@ def validate_attention_method(
     # 1. `max_seq_len` should be multiple of `partition_len`.
     # 2. 4k <= `partition_len` <= 32k.
     # 3. `max_seq_len` should be larger then 8k.
-    if rbln_attn_impl == "eager" and rbln_max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
+    if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
         raise ValueError(
-            f"`rbln_max_seq_len` is set to {rbln_max_seq_len}, "
+            f"`max_seq_len` is set to {max_seq_len}, "
             f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
-            f"Please reduce the `rbln_max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
-            " or consider switching `rbln_attn_impl` to 'flash_attn' for larger sequence lengths."
+            f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
+            " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
         )
 
-    if rbln_attn_impl == "flash_attn":
-        if rbln_max_seq_len // rbln_kvcache_partition_len < 2 or rbln_max_seq_len % rbln_kvcache_partition_len != 0:
+    if attn_impl == "flash_attn":
+        if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
             raise ValueError(
-                f"`rbln_max_seq_len` ({rbln_max_seq_len}) must be a multiple of `rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) "
+                f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
                 f"when using 'flash_attn'. Please adjust either value to meet this requirement."
             )
-        elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= rbln_kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
+        elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
             raise ValueError(
-                f"`rbln_kvcache_partition_len` ({rbln_kvcache_partition_len}) is out of the supported range for 'flash_attn' "
-                f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `rbln_kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
+                f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
+                f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
                 f"Please provide a valid value within this range."
            )
-        elif rbln_max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
+        elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
            raise ValueError(
-                f"`rbln_max_seq_len` ({rbln_max_seq_len}) is too small for 'flash_attn'. The minimum "
-                f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `rbln_max_seq_len` to meet "
-                "this requirement, or consider switching `rbln_attn_impl` to 'eager' for shorter lengths."
+                f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
+                f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
+                "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
            )
 
-    if rbln_kvcache_block_size is not None:
-        if rbln_attn_impl == "flash_attn" and rbln_kvcache_partition_len != rbln_kvcache_block_size:
+    if kvcache_block_size is not None:
+        if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
             raise ValueError(
-                f" When using 'flash attention', the `rbln_kvcache_block_size` ({rbln_kvcache_block_size}) "
-                f"must always be set equal to the `rbln_kvcache_partition_len` {rbln_kvcache_partition_len}."
+                f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
+                f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
            )
-        elif rbln_attn_impl == "eager" and rbln_kvcache_block_size != rbln_max_seq_len:
+        elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
             raise ValueError(
-                f" When using 'eager attention', the `rbln_kvcache_block_size` ({rbln_kvcache_block_size}) "
-                f"must always be set equal to the `rbln_max_seq_len` {rbln_max_seq_len}."
+                f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
+                f"must always be set equal to the `max_seq_len` {max_seq_len}."
            )
 
-    return rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size
-
 
 
 class DecoderOnlyWrapper(nn.Module):
     """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.
```
```diff
@@ -213,6 +220,53 @@ class DecoderOnlyWrapper(nn.Module):
         self._phase = phase
         self.causal_lm.phase = phase
 
+    def forward_common(
+        self,
+        input_ids_or_inputs_embeds: torch.Tensor,
+        cache_position: torch.Tensor,
+        attention_mask: torch.Tensor,
+        query_position: torch.Tensor,
+        block_tables: torch.Tensor,
+        rotary_emb: Union[nn.Module, torch.Tensor],
+        *past_key_values: List[torch.Tensor],
+    ):
+        if input_ids_or_inputs_embeds.ndim == 2:
+            input_ids = input_ids_or_inputs_embeds
+            inputs_embeds = None
+        elif input_ids_or_inputs_embeds.ndim == 3:
+            input_ids = None
+            inputs_embeds = input_ids_or_inputs_embeds
+        else:
+            raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+
+        if len(past_key_values) != 2 * self.num_hidden_layers:
+            raise ValueError(
+                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
+            )
+
+        # [key, value] * n_layer -> ( (key, value) ) * n_layer
+        # cache shape : batch, n_heads, 1, max_seq_len, head_dim
+        _past_key_values = []
+        for i in range(self.config.num_hidden_layers):
+            key_states = past_key_values[i * 2]
+            value_states = past_key_values[i * 2 + 1]
+            past_key_value = [key_states, value_states]
+            _past_key_values.append(past_key_value)
+        past_key_values = _past_key_values
+
+        logit = self.causal_lm(
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            query_position=query_position,
+            past_key_values=past_key_values,
+            rotary_emb=rotary_emb,
+            block_tables=block_tables,
+        )
+
+        return logit
+
     def forward(self, *args):
         if self.phase == "decode":
             if self.use_attention_mask:
```
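`forward_common` still receives the KV cache as a flat trailing argument list, one key and one value tensor per layer, and regroups it into per-layer pairs before calling the causal LM. A standalone sketch of that packing convention (sizes are illustrative; the shape comment comes from the code above):

```python
import torch

num_layers, batch, n_heads, max_seq_len, head_dim = 2, 1, 8, 1024, 64

# Flat layout handed to the wrapper: [k0, v0, k1, v1, ...]
# cache shape : batch, n_heads, 1, max_seq_len, head_dim
flat_cache = [torch.zeros(batch, n_heads, 1, max_seq_len, head_dim) for _ in range(2 * num_layers)]

# Regrouping performed inside forward_common: [key, value] per layer.
past_key_values = [[flat_cache[i * 2], flat_cache[i * 2 + 1]] for i in range(num_layers)]
assert len(past_key_values) == num_layers
```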
```diff
@@ -255,43 +309,16 @@ class DecoderOnlyWrapper(nn.Module):
         else:
             raise ValueError(f"Unknown phase: {self.phase}")
 
-        if input_ids_or_inputs_embeds.ndim == 2:
-            input_ids = input_ids_or_inputs_embeds
-            inputs_embeds = None
-        elif input_ids_or_inputs_embeds.ndim == 3:
-            input_ids = None
-            inputs_embeds = input_ids_or_inputs_embeds
-        else:
-            raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
-
-        if len(past_key_values) != 2 * self.num_hidden_layers:
-            raise ValueError(
-                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
-            )
-
-        # [key, value] * n_layer -> ( (key, value) ) * n_layer
-        # cache shape : batch, n_heads, 1, max_seq_len, head_dim
-        _past_key_values = []
-        for i in range(self.config.num_hidden_layers):
-            key_states = past_key_values[i * 2]
-            value_states = past_key_values[i * 2 + 1]
-            past_key_value = [key_states, value_states]
-            _past_key_values.append(past_key_value)
-        past_key_values = _past_key_values
-
-        logit = self.causal_lm(
-            input_ids=input_ids,
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-            query_position=query_position,
-            past_key_values=past_key_values,
-            rotary_emb=self.rotary_emb,
-            block_tables=block_tables,
+        return self.forward_common(
+            input_ids_or_inputs_embeds,
+            cache_position,
+            attention_mask,
+            query_position,
+            block_tables,
+            self.rotary_emb,
+            *past_key_values,
         )
 
-        return logit
-
 
 class DecoderOnlyForCausalLM(nn.Module):
     """A specialized wrapper for Causal Language Models optimized for RBLN compilation.
```
```diff
@@ -315,12 +342,13 @@ class DecoderOnlyForCausalLM(nn.Module):
         _phase: Current processing phase ("prefill" or "decode")
     """
 
-    def __init__(self, causal_lm: PreTrainedModel, model):
+    def __init__(self, causal_lm: PreTrainedModel, model: nn.Module):
         super().__init__()
         self.config = causal_lm.config
         self._original_mod = causal_lm
         self.model = model
         self._phase = "prefill"
+        self.lm_head = self._original_mod.lm_head
 
     @property
     def phase(self):
```
```diff
@@ -356,7 +384,7 @@ class DecoderOnlyForCausalLM(nn.Module):
         if self.phase == "prefill":
             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
 
-        logits = self._original_mod.lm_head(hidden_states)
+        logits = self.lm_head(hidden_states)
         return logits
 
 
```
```diff
@@ -448,8 +476,12 @@ class DecoderOnlyModel(nn.Module):
 
         # get cos,sin vector if needed
         if rotary_emb is not None:
-            cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
-            cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
+            if isinstance(rotary_emb, torch.Tensor):
+                cos = rotary_emb[0]
+                sin = rotary_emb[1]
+            else:
+                cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
+            cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
         else:
             batch_size = inputs_embeds.shape[0]
             if cache_position.shape[0] > 1:
```
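With this change, `rotary_emb` may be either the usual rotary-embedding module or a precomputed tensor whose first and second slices are the cos and sin tables. A sketch of building such a stacked tensor with the standard RoPE recipe (the base and shapes here are illustrative assumptions, not taken from this diff):

```python
import torch

max_seq_len, head_dim = 1024, 64

# Standard RoPE tables (illustrative): inverse frequencies over even channel indices.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)

# Stacked [cos, sin]: the new branch reads cos = rotary_emb[0], sin = rotary_emb[1].
rotary_emb = torch.stack([emb.cos(), emb.sin()])
```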
```diff
@@ -826,7 +858,6 @@ def rotate_half(x):
 
 def apply_rotary_pos_emb(q, k, cos, sin):
     """Applies Rotary Position Embedding to the query and key tensors."""
-
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
```
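For reference, `apply_rotary_pos_emb` composes with the `rotate_half` helper, whose body falls outside this hunk; assuming the standard Hugging Face definition of that helper, a self-contained usage sketch:

```python
import torch


def rotate_half(x):
    # Assumed standard definition; only the signature appears in the hunk above.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Applies Rotary Position Embedding to the query and key tensors."""
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


batch, n_heads, seq_len, head_dim = 1, 8, 16, 64
q = torch.randn(batch, n_heads, seq_len, head_dim)
k = torch.randn(batch, n_heads, seq_len, head_dim)
cos = torch.ones(seq_len, head_dim)   # identity rotation for illustration
sin = torch.zeros(seq_len, head_dim)
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
assert torch.equal(q_embed, q) and torch.equal(k_embed, k)
```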