optimum-rbln 0.8.2a7__py3-none-any.whl → 0.8.3a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +8 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/configuration_utils.py +4 -4
- optimum/rbln/diffusers/__init__.py +1 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/pipelines/__init__.py +1 -5
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +2 -2
- optimum/rbln/modeling_base.py +12 -4
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/transformers/__init__.py +6 -0
- optimum/rbln/transformers/configuration_generic.py +4 -4
- optimum/rbln/transformers/modeling_generic.py +1 -4
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +10 -16
- optimum/rbln/transformers/models/auto/__init__.py +1 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
- optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -93
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +297 -987
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +14 -3
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +58 -257
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
- optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
- optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
- optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +249 -46
- optimum/rbln/utils/runtime_utils.py +3 -3
- {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/RECORD +90 -86
- {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/whisper/modeling_whisper.py

@@ -46,7 +46,7 @@ if TYPE_CHECKING:
 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
-    def forward(self, *args: List[torch.Tensor], **kwargs:
+    def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
         output = super().forward(*args, **kwargs)
         return BaseModelOutput(last_hidden_state=output)
 
@@ -253,6 +253,23 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
 
         return {"encoder": compiled_encoder, "decoder": compiled_decoder}
 
+    @classmethod
+    def _update_paged_attention_config(
+        cls, model_config: "PretrainedConfig", rbln_config: RBLNWhisperForConditionalGenerationConfig
+    ):
+        rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
+        rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
+
+        if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
+            raise NotImplementedError(
+                f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
+            )
+
+        if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
+            raise NotImplementedError(
+                f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
+            )
+
     @classmethod
     def _update_rbln_config(
         cls,
@@ -270,6 +287,8 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         if rbln_config.dec_max_seq_len is None:
             rbln_config.dec_max_seq_len = model_config.max_length
 
+        cls._update_paged_attention_config(model_config, rbln_config)
+
         enc_input_info = [
             ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
             ("block_tables", [1], "int16"),
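The new `_update_paged_attention_config` hook defaults `kvcache_num_blocks` to the batch size and `kvcache_block_size` to `dec_max_seq_len`, then rejects any other values because Whisper decoding does not use flash attention yet. A minimal standalone sketch of that defaulting-and-validation logic (the dataclass below is a hypothetical stand-in for the relevant fields of `RBLNWhisperForConditionalGenerationConfig`, and 448 is only an example decoder length):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class WhisperPagedAttentionSettings:
    # Hypothetical stand-in for the paged-attention fields of the real config class.
    batch_size: int
    dec_max_seq_len: int
    kvcache_num_blocks: Optional[int] = None
    kvcache_block_size: Optional[int] = None


def apply_paged_attention_defaults(cfg: WhisperPagedAttentionSettings) -> WhisperPagedAttentionSettings:
    # Default: one KV-cache block per batch element, block size = decoder max length.
    cfg.kvcache_num_blocks = cfg.kvcache_num_blocks or cfg.batch_size
    cfg.kvcache_block_size = cfg.kvcache_block_size or cfg.dec_max_seq_len

    # Anything else would require flash attention, which Whisper does not support yet.
    if cfg.kvcache_num_blocks != cfg.batch_size:
        raise NotImplementedError("kvcache_num_blocks must equal batch_size")
    if cfg.kvcache_block_size != cfg.dec_max_seq_len:
        raise NotImplementedError("kvcache_block_size must equal dec_max_seq_len")
    return cfg


cfg = apply_paged_attention_defaults(WhisperPagedAttentionSettings(batch_size=2, dec_max_seq_len=448))
print(cfg.kvcache_num_blocks, cfg.kvcache_block_size)  # 2 448
```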
optimum/rbln/transformers/models/xlm_roberta/__init__.py

@@ -12,14 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .configuration_xlm_roberta import (
-    RBLNXLMRobertaForSequenceClassificationConfig,
-    RBLNXLMRobertaModelConfig,
-)
-from .modeling_xlm_roberta import (
-    RBLNXLMRobertaForSequenceClassification,
-    RBLNXLMRobertaModel,
-)
+from .configuration_xlm_roberta import RBLNXLMRobertaForSequenceClassificationConfig, RBLNXLMRobertaModelConfig
+from .modeling_xlm_roberta import RBLNXLMRobertaForSequenceClassification, RBLNXLMRobertaModel
 
 
 __all__ = [
optimum/rbln/transformers/utils/rbln_quantization.py

@@ -13,10 +13,12 @@
 # limitations under the License.
 
 import glob
+import json
 import os
 from typing import Any, Dict, Optional, Union
 
 import torch
+from huggingface_hub import hf_hub_download, list_repo_files
 from safetensors.torch import load_file
 from torch.nn import Linear, Parameter
 from torch.nn import functional as F
@@ -30,21 +32,24 @@ logger = get_logger()
 
 class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
     SUPPORTED_FORMATS = ["rbln"]
-    SUPPORTED_WEIGHTS = ["int4", "fp16"]
-    SUPPORTED_ACTIVATIONS = ["fp16"]
-
-    # The RBLN_QUANT_BITS environment variable defines the precision of each layer during the graph compilation process.
-    # It specifies the quantization bit depth. For instance, setting RBLN_QUANT_BITS=4 will apply 4-bit precision for quantization.
+    SUPPORTED_WEIGHTS = ["int4", "fp8", "fp16"]
+    SUPPORTED_ACTIVATIONS = ["fp8", "fp16"]
+    SUPPORTED_KVCACHES = ["fp8", "fp16"]
     RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
 
     def __init__(
         self,
         format: Optional[str] = None,
-        precision: Optional[str] = None,
         weights: Optional[str] = None,
         activations: Optional[str] = None,
+        kv_caches: Optional[str] = None,
+        *,
+        precision: Optional[str] = None,
     ):
-        self.format = format
+        self.format = format or "rbln"
+        if self.format not in self.SUPPORTED_FORMATS:
+            raise ValueError(f"Invalid format: {self.format}, supported formats are: {self.SUPPORTED_FORMATS}")
+
         if precision is not None:
             logger.warning("The `precision` argument is deprecated. Use `weights` and `activations` instead.")
             if any(precision_arg is not None for precision_arg in (weights, activations)):
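With this hunk, `precision` becomes keyword-only and deprecated, while `weights`, `activations`, and the new `kv_caches` each accept the values listed in the SUPPORTED_* constants. Assuming the module path shown in the file list above, constructing an all-fp8 config would look roughly like this (a sketch, not taken from the package documentation):

```python
from optimum.rbln.transformers.utils.rbln_quantization import RBLNQuantizationConfig

# fp8 weights, activations, and KV caches; "rbln" is the only supported format and the default.
quant_cfg = RBLNQuantizationConfig(
    format="rbln",
    weights="fp8",
    activations="fp8",
    kv_caches="fp8",
)

# The old `precision` argument is still accepted, but only as a keyword and with a deprecation warning.
```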
@@ -58,6 +63,8 @@ class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
 
         self.weights = weights or "fp16"
         self.activations = activations or "fp16"
+        self.kv_caches = kv_caches or "fp16"
+
         self._validate()
 
     def _validate(self):
@@ -69,27 +76,49 @@ class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
             raise ValueError(
                 f"Invalid activations: {self.activations}, supported activations are: {self.SUPPORTED_ACTIVATIONS}"
             )
+        if self.kv_caches not in self.SUPPORTED_KVCACHES:
+            raise ValueError(
+                f"Invalid kv_caches: {self.kv_caches}, supported kv_caches are: {self.SUPPORTED_KVCACHES}"
+            )
         if self.weights == "fp16" and self.activations == "fp16":
-            raise ValueError("weights and activations cannot be both fp16. It is meaningless.")
+            raise ValueError("weights and activations of QuantizationConfig cannot be both fp16. It is meaningless.")
 
     def _prepare_for_serialization(self) -> Dict[str, Any]:
         return {
             "format": self.format,
             "weights": self.weights,
             "activations": self.activations,
+            "kv_caches": self.kv_caches,
         }
 
     def maybe_set_quantization_env(self):
-        quant_bits = None
         if self.weights == "int4":
-
-            os.environ[self.RBLN_QUANT_BITS_ENV] = quant_bits
+            os.environ[self.RBLN_QUANT_BITS_ENV] = "4"
 
     def maybe_reset_quantization_env(self):
         if self.RBLN_QUANT_BITS_ENV in os.environ:
             os.environ.pop(self.RBLN_QUANT_BITS_ENV)
 
 
+class QuantizedLayerFactory:
+    def __init__(self, quantization_config: RBLNQuantizationConfig):
+        self.quantization_config = quantization_config
+
+    def create_linear(self, layer: Linear) -> Linear:
+        if self.quantization_config.weights == "int4":
+            return self.create_qlinear(layer)
+        elif self.quantization_config.weights == "fp8":
+            return self.create_fp8linear(layer)
+        else:
+            raise ValueError(f"Invalid quantization weights: {self.quantization_config.weights}")
+
+    def create_qlinear(self, layer: Linear) -> Linear:
+        return create_qlinear(layer, self.quantization_config)
+
+    def create_fp8linear(self, layer: Linear) -> Linear:
+        return create_fp8linear(layer, self.quantization_config)
+
+
 # Constants
 QUANTIZED_WEIGHTS = {
     "q_proj",
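`QuantizedLayerFactory` is a thin dispatcher on the configured weight dtype: `int4` goes through `create_qlinear`, `fp8` through `create_fp8linear`, and anything else is rejected. A toy sketch of the same dispatch pattern with hypothetical placeholder converters (not the package's own classes):

```python
from torch.nn import Linear


class LayerFactory:
    # Hypothetical stand-in that mirrors QuantizedLayerFactory's dispatch on `weights`.
    def __init__(self, weights: str):
        self.weights = weights

    def create_linear(self, layer: Linear) -> Linear:
        if self.weights == "int4":
            return self.to_int4(layer)  # the real code calls create_qlinear(layer, config)
        elif self.weights == "fp8":
            return self.to_fp8(layer)   # the real code calls create_fp8linear(layer, config)
        raise ValueError(f"Invalid quantization weights: {self.weights}")

    def to_int4(self, layer: Linear) -> Linear:
        return layer  # placeholder for the int4 conversion

    def to_fp8(self, layer: Linear) -> Linear:
        return layer  # placeholder for the fp8 conversion


converted = LayerFactory(weights="fp8").create_linear(Linear(16, 16))
```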
@@ -111,64 +140,60 @@ def prepare_model_for_quantization(
     cache_dir: Optional[str] = None,
     force_download: bool = False,
     local_files_only: bool = False,
+    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
 ) -> torch.nn.Module:
     """
     Prepare the model for quantization by updating specified linear layers to quantized (qlinear) layers.
     """
-
-
-
+
+    # 1. Load weight files and safetensors.index.json
+    safetensor_files, index_data = load_weight_files_and_index(
         model_id,
-        n_layer,
         use_auth_token=use_auth_token,
         revision=revision,
         cache_dir=cache_dir,
         force_download=force_download,
         local_files_only=local_files_only,
     )
-    return model
 
+    # 2. Determine format from safetensors.index.json
+    determined_format = determine_format_from_index(index_data)
 
-
-
-    Updates specified linear layers to quantized (qlinear) layers in the given module.
-    """
-
-    logger.debug("Updating layers to be quantized")  # TODO(jongho): remove.
-    processed_layers = []
+    # 3. Update linear layers based on the determined format
+    update_layers_to_quantize(model, rbln_quantization)
 
-
-
-
-
-
+    # 4. Load weights into model parameters
+    load_weights_from_files(
+        model,
+        safetensor_files,
+        n_layer,
+        rbln_quantization=rbln_quantization,
+        determined_format=determined_format,
+    )
 
-
-    logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
+    return model
 
 
-def
-
-
-
-
-
-
-    local_files_only=False,
-):
+def load_weight_files_and_index(
+    model_id: str,
+    use_auth_token: Optional[Union[bool, str]] = None,
+    revision: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+    force_download: bool = False,
+    local_files_only: bool = False,
+) -> tuple[list[str], Optional[Dict]]:
     """
     Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
     """
-
-    model_params = dict(model.named_parameters(recurse=True))
-    model_buffers = dict(model.named_buffers(recurse=True))
+    index_data = None
 
     if os.path.isdir(model_id):
         safetensor_files = glob.glob(f"{model_id}/*.safetensors")
+        index_path = os.path.join(model_id, "model.safetensors.index.json")
+        if os.path.exists(index_path):
+            with open(index_path, "r") as f:
+                index_data = json.load(f)
     else:
-        from huggingface_hub import hf_hub_download, list_repo_files
-
        try:
            # List all files in the repository
            repo_files = list_repo_files(model_id, revision=revision, token=use_auth_token)
@@ -188,6 +213,20 @@ def load_weights(
                     local_files_only=local_files_only,
                 )
                 safetensor_files.append(downloaded_file)
+            elif file == "model.safetensors.index.json":
+                # Download the index file
+                index_file = hf_hub_download(
+                    repo_id=model_id,
+                    filename=file,
+                    revision=revision,
+                    token=use_auth_token,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    local_files_only=local_files_only,
+                )
+
+                with open(index_file, "r") as f:
+                    index_data = json.load(f)
        except Exception as e:
            logger.error(f"Failed to download safetensors files from Hugging Face Hub: {e}")
            raise e
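The Hub branch now also fetches `model.safetensors.index.json` alongside the shards. A minimal sketch of that lookup outside the package; `hf_hub_download` and `list_repo_files` are real `huggingface_hub` functions, while the repo id below is only a placeholder:

```python
import json

from huggingface_hub import hf_hub_download, list_repo_files

repo_id = "some-org/some-quantized-model"  # placeholder repo id

index_data = None
for file in list_repo_files(repo_id):
    if file == "model.safetensors.index.json":
        index_file = hf_hub_download(repo_id=repo_id, filename=file)
        with open(index_file, "r") as f:
            index_data = json.load(f)

# index_data["weight_map"] maps every tensor name to the shard file that stores it.
```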
@@ -195,12 +234,85 @@ def load_weights(
     if not safetensor_files:
         raise FileNotFoundError(f"No safetensors files found for model_id: {model_id}")
 
+    return safetensor_files, index_data
+
+
+def determine_format_from_index(index_data: Optional[Dict]) -> str:
+    """
+    Determine the quantization format from safetensors.index.json data.
+
+    Args:
+        index_data: The loaded safetensors.index.json content
+
+    Returns:
+        str: The determined format string
+    """
+    if index_data is None:
+        raise ValueError("safetensors.index.json not found")
+    if "weight_map" not in index_data:
+        raise ValueError("weight_map not found in safetensors.index.json")
+
+    if any("self_attn.k_proj.k_scale" in key for key in index_data["weight_map"]):
+        return "tensorrt"
+    elif any("self_attn.kv_scale" in key for key in index_data["weight_map"]):
+        return "quark"
+    elif any("weight_scale" in key or "input_scale" in key for key in index_data["weight_map"]):
+        return "default"
+    else:
+        raise ValueError("Unknown quantization format of the index data of weight map.")
+
+
+def update_layers_to_quantize(
+    module: torch.nn.Module,
+    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
+) -> None:
+    """
+    Updates specified linear layers to quantized (qlinear) layers in the given module.
+    """
+
+    processed_layers = []
+    quantized_layer_factory = QuantizedLayerFactory(rbln_quantization)
+
+    for name, layer in module.named_modules():
+        if is_target_for_qlinear_replacement(name, layer):
+            parent_module, layer_name = get_parent_and_child(module, name)
+            setattr(parent_module, layer_name, quantized_layer_factory.create_linear(layer))
+            processed_layers.append(name)
+
+    if processed_layers:
+        logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
+
+
+def load_weights_from_files(
+    model: torch.nn.Module,
+    safetensor_files: list[str],
+    n_layer: Optional[int] = None,
+    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
+    determined_format: Optional[str] = None,
+):
+    """
+    Load safetensor file data directly into the model from provided safetensor files,
+    filtering by layer if n_layer is provided.
+    """
+
+    model_params = dict(model.named_parameters(recurse=True))
+    model_buffers = dict(model.named_buffers(recurse=True))
+
     target_layers = list(range(n_layer)) if n_layer is not None else None
 
     unloaded_keys = []
+    loaded_input_scale = False
+    loaded_kv_scale = False
+    loaded_weight_scale = False
+
     for safetensor_file in safetensor_files:
         file_data = load_file(safetensor_file)
+
         for key, value in file_data.items():
+            loaded_input_scale = loaded_input_scale or "input_scale" in key
+            loaded_weight_scale = loaded_weight_scale or "weight_scale" in key
+            loaded_kv_scale = loaded_kv_scale or any(scale in key for scale in ["kv_scale", "k_scale", "v_scale"])
+
             if target_layers is not None:
                 parts = key.split(".")
 
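`determine_format_from_index` infers the checkpoint flavor purely from the tensor names in `weight_map`: `self_attn.k_proj.k_scale` keys map to "tensorrt", a fused `self_attn.kv_scale` maps to "quark", and bare `weight_scale`/`input_scale` keys fall back to "default". The index fragments below are illustrative, not taken from a real checkpoint, and the import assumes the 0.8.3a0 module layout:

```python
from optimum.rbln.transformers.utils.rbln_quantization import determine_format_from_index

# Illustrative weight_map fragments; a real index file lists every tensor and its shard.
quark_index = {
    "weight_map": {
        "model.layers.0.self_attn.kv_scale": "model-00001-of-00002.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    }
}
default_index = {
    "weight_map": {
        "model.layers.0.self_attn.q_proj.weight_scale": "model.safetensors",
        "model.layers.0.self_attn.q_proj.input_scale": "model.safetensors",
    }
}

print(determine_format_from_index(quark_index))    # "quark"   (matched on "self_attn.kv_scale")
print(determine_format_from_index(default_index))  # "default" (matched on weight_scale/input_scale)
```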
@@ -211,12 +323,38 @@ def load_weights(
                 model_params[key].data.copy_(value)
             elif key in model_buffers:
                 model_buffers[key].data.copy_(value)
+            elif "kv_scale" in key and determined_format == "quark":
+                if rbln_quantization.kv_caches == "fp8":
+                    model_params[key.replace("kv_scale", "k_proj.k_scale")].data.copy_(value)
+                    model_params[key.replace("kv_scale", "v_proj.v_scale")].data.copy_(value)
+                else:
+                    unloaded_keys.append(key)
             else:
                 unloaded_keys.append(key)
 
     if len(unloaded_keys) > 0:
         logger.warning(f"There are unexpected parameters/buffers on the checkpoint: {unloaded_keys}")
 
+    if not loaded_input_scale and rbln_quantization.activations == "fp8":
+        raise ValueError(
+            "No input_scale found in the checkpoint. Did you use the correct quantization config? "
+            "If you are using fp8 quantization, you need to use the correct quantization config."
+        )
+    if not loaded_weight_scale and rbln_quantization.weights == "fp8":
+        raise ValueError(
+            "No weight_scale found in the checkpoint. Did you use the correct quantization config? "
+            "If you are using fp8 quantization, you need to use the correct quantization config."
+        )
+    if not loaded_kv_scale and rbln_quantization.kv_caches == "fp8":
+        raise ValueError(
+            "No kv_scale found in the checkpoint. Did you use the correct quantization config? "
+            "If you are using fp8 quantization, you need to use the correct quantization config."
+        )
+    if loaded_kv_scale and rbln_quantization.kv_caches != "fp8":
+        logger.warning(
+            "kv_scale found in the checkpoint, but kv_caches of quantization config is not fp8. Ignoring kv_scale."
+        )
+
 
 def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
     """
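For "quark"-format checkpoints, a fused `kv_scale` tensor is copied into the separate `k_proj.k_scale` and `v_proj.v_scale` parameters, but only when `kv_caches` is fp8; otherwise it is reported as an unloaded key. The remapping is a plain string substitution on the tensor name, for example:

```python
key = "model.layers.3.self_attn.kv_scale"  # example Quark-style key, not from a real checkpoint

k_key = key.replace("kv_scale", "k_proj.k_scale")
v_key = key.replace("kv_scale", "v_proj.v_scale")

print(k_key)  # model.layers.3.self_attn.k_proj.k_scale
print(v_key)  # model.layers.3.self_attn.v_proj.v_scale
```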
@@ -225,6 +363,10 @@ def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -
     return layer_name.split(".")[-1] in QUANTIZED_WEIGHTS and isinstance(layer, torch.nn.Linear)
 
 
+def is_target_for_adding_kv_scales(layer_name: str) -> bool:
+    return layer_name.split(".")[-1] in ["self_attn"]
+
+
 def get_parent_and_child(module: torch.nn.Module, full_name: str) -> tuple:
     """
     Splits the full layer name to retrieve the parent module and the child layer.
@@ -243,7 +385,7 @@ def access_attribute(obj: Any, attributes: list[str]) -> Any:
     return obj
 
 
-def create_qlinear(layer: Linear) -> Linear:
+def create_qlinear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
     """
     Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass.
     """
@@ -262,3 +404,64 @@ def create_qlinear(layer: Linear) -> Linear:
     layer.forward = lambda inputs: qlinear_forward(layer, inputs)
 
     return layer
+
+
+def create_fp8linear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
+    """
+    Converts a standard linear layer to a fp8 linear layer with a custom forward pass.
+    """
+
+    def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
+        finfo = torch.finfo(torch.float8_e4m3fn)
+        qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
+        return qweight
+
+    def fp8_gemm(A: torch.Tensor, A_scale, B: torch.Tensor, B_scale, bias, out_dtype: torch.dtype):
+        A = A.type(out_dtype)
+        B = B.type(out_dtype)
+
+        if A_scale is not None:
+            A *= A_scale
+        if B_scale is not None:
+            B *= B_scale.to(out_dtype)
+
+        output = torch.nn.functional.linear(A, B, bias=bias)
+        return output
+
+    def fp8linear_forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.input_scale:
+            input = static_per_tensor_quantize(x, self.input_scale)
+        else:
+            input = x
+
+        if self.weight_scale:
+            # broadcast weight_scale to vector
+            weight_scale = self.weight_scale.broadcast_to(self.weight.shape[-1:])
+        else:
+            weight_scale = None
+        output = fp8_gemm(
+            A=input,
+            A_scale=self.input_scale,
+            B=self.weight,
+            B_scale=weight_scale,
+            bias=self.bias,
+            out_dtype=x.dtype,
+        )
+
+        return output
+
+    layer.weight = Parameter(layer.weight.to(torch.float8_e4m3fn), requires_grad=False)
+    layer.weight_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+
+    if rbln_quantization.activations == "fp8":
+        layer.input_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+    else:
+        layer.input_scale = None
+
+    if rbln_quantization.kv_caches == "fp8":
+        layer.k_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+        layer.v_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+
+    layer.forward = lambda inputs: fp8linear_forward(layer, inputs)
+
+    return layer
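`create_fp8linear` keeps the weight stored as `torch.float8_e4m3fn`, attaches per-tensor scale parameters, and emulates the matmul in `fp8_gemm` by up-casting and rescaling. A self-contained sketch of the underlying static per-tensor quantize/dequantize step (the scale here is derived from the data purely for illustration; requires a PyTorch build with float8 support):

```python
import torch

finfo = torch.finfo(torch.float8_e4m3fn)

x = torch.randn(4, 8)
scale = x.abs().max() / finfo.max  # per-tensor scale chosen so the data fits the fp8 range

# quantize: divide by the scale, clamp to the fp8 range, cast down to float8_e4m3fn
q = (x / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)

# dequantize: cast back up and multiply by the same scale (what fp8_gemm does before the matmul)
x_hat = q.to(torch.float32) * scale

print((x - x_hat).abs().max())  # small quantization error
```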
optimum/rbln/utils/runtime_utils.py

@@ -14,7 +14,7 @@
 
 import re
 import threading
-from typing import Any,
+from typing import Any, List, Optional, Union
 
 import rebel
 import torch
@@ -94,7 +94,7 @@ class RBLNPytorchRuntime:
     def __call__(self, *args: Any, **kwds: Any) -> Any:
         return self.forward(*args, **kwds)
 
-    def forward(self, *args: List["torch.Tensor"], **kwargs:
+    def forward(self, *args: List["torch.Tensor"], **kwargs: "torch.Tensor"):
         # filtering useless args or kwarg such as None.
         args = list(filter(lambda arg: isinstance(arg, torch.Tensor), args))
         kwargs = dict(filter(lambda kwarg: isinstance(kwarg[1], torch.Tensor) or kwarg[0] == "out", kwargs.items()))
@@ -142,7 +142,7 @@ class UnavailableRuntime:
         """Returns an iterator with self as the only item."""
         return iter([self])
 
-    def forward(self, *args: List["torch.Tensor"], **kwargs:
+    def forward(self, *args: List["torch.Tensor"], **kwargs: "torch.Tensor"):
         """Raises a detailed RuntimeError explaining why inference cannot be performed."""
         raise RuntimeError(
             "Cannot perform inference: RBLN runtime is not available.\n\n"
{optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.8.2a7
+Version: 0.8.3a0
 Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai