optimum-rbln 0.8.2a7__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Note: the registry flags this version of optimum-rbln as a potentially problematic release.
- optimum/rbln/__init__.py +36 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/configuration_utils.py +20 -4
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/pipelines/__init__.py +11 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +237 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +3 -2
- optimum/rbln/modeling_base.py +29 -4
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/transformers/__init__.py +28 -0
- optimum/rbln/transformers/configuration_generic.py +6 -4
- optimum/rbln/transformers/modeling_generic.py +13 -8
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +35 -16
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +14 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +7 -6
- optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -93
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +297 -987
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +14 -3
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +64 -258
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +86 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +507 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
- optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
- optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
- optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +365 -65
- optimum/rbln/utils/runtime_utils.py +3 -3
- optimum/rbln/utils/submodule.py +10 -4
- {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/RECORD +105 -89
- {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/utils/rbln_quantization.py
CHANGED

@@ -14,9 +14,10 @@
 
 import glob
 import os
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
+from huggingface_hub import hf_hub_download, list_repo_files
 from safetensors.torch import load_file
 from torch.nn import Linear, Parameter
 from torch.nn import functional as F
@@ -28,23 +29,47 @@ from ...utils.logging import get_logger
 logger = get_logger()
 
 
+# Constants
+QUANTIZED_WEIGHTS = {
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+}
+
+# Common alias sets seen in community checkpoints
+VARIANT_ALIASES: Dict[str, List[str]] = {
+    "weight_scale": ["weight_scale", "scales", "w_scale", "scale"],
+    "input_scale": ["input_scale", "act_scale", "activation_scale", "a_scale"],
+    "kv_scale": ["kv_scale", "kv_scales"],
+    "k_scale": ["k_scale", "k_scales"],
+    "v_scale": ["v_scale", "v_scales"],
+}
+
+
 class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
     SUPPORTED_FORMATS = ["rbln"]
-    SUPPORTED_WEIGHTS = ["int4", "fp16"]
-    SUPPORTED_ACTIVATIONS = ["fp16"]
-
-    # The RBLN_QUANT_BITS environment variable defines the precision of each layer during the graph compilation process.
-    # It specifies the quantization bit depth. For instance, setting RBLN_QUANT_BITS=4 will apply 4-bit precision for quantization.
+    SUPPORTED_WEIGHTS = ["int4", "int8", "fp8", "fp16"]
+    SUPPORTED_ACTIVATIONS = ["int8", "fp8", "fp16"]
+    SUPPORTED_KVCACHES = ["fp8", "fp16"]
     RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
 
     def __init__(
         self,
         format: Optional[str] = None,
-        precision: Optional[str] = None,
         weights: Optional[str] = None,
         activations: Optional[str] = None,
+        kv_caches: Optional[str] = None,
+        *,
+        precision: Optional[str] = None,
     ):
-        self.format = format
+        self.format = format or "rbln"
+        if self.format not in self.SUPPORTED_FORMATS:
+            raise ValueError(f"Invalid format: {self.format}, supported formats are: {self.SUPPORTED_FORMATS}")
+
         if precision is not None:
             logger.warning("The `precision` argument is deprecated. Use `weights` and `activations` instead.")
             if any(precision_arg is not None for precision_arg in (weights, activations)):
@@ -58,6 +83,7 @@ class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
 
         self.weights = weights or "fp16"
         self.activations = activations or "fp16"
+        self.kv_caches = kv_caches or "fp16"
         self._validate()
 
     def _validate(self):
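For orientation, a minimal usage sketch of the reworked constructor. The import path is inferred from the file list above, and the behaviour of the deprecated `precision` path beyond the warning is not visible in this diff, so treat those parts as assumptions:

    # Assumed import path, based on the file list above.
    from optimum.rbln.transformers.utils.rbln_quantization import RBLNQuantizationConfig

    # New style: choose weight/activation/kv-cache precision explicitly.
    cfg = RBLNQuantizationConfig(weights="int4")  # activations and kv_caches default to "fp16"

    # Old style still parses, but `precision` is now keyword-only and logs a deprecation
    # warning; how it maps onto weights/activations happens in code not shown in this diff.
    legacy = RBLNQuantizationConfig(precision="int4")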
@@ -69,37 +95,47 @@ class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
             raise ValueError(
                 f"Invalid activations: {self.activations}, supported activations are: {self.SUPPORTED_ACTIVATIONS}"
             )
+        if self.kv_caches not in self.SUPPORTED_KVCACHES:
+            raise ValueError(
+                f"Invalid kv_caches: {self.kv_caches}, supported kv_caches are: {self.SUPPORTED_KVCACHES}"
+            )
         if self.weights == "fp16" and self.activations == "fp16":
-            raise ValueError("weights and activations cannot be both fp16. It is meaningless.")
+            raise ValueError("weights and activations of QuantizationConfig cannot be both fp16. It is meaningless.")
 
     def _prepare_for_serialization(self) -> Dict[str, Any]:
         return {
             "format": self.format,
             "weights": self.weights,
             "activations": self.activations,
+            "kv_caches": self.kv_caches,
         }
 
     def maybe_set_quantization_env(self):
-        quant_bits = None
         if self.weights == "int4":
-            quant_bits = "4"
-            os.environ[self.RBLN_QUANT_BITS_ENV] = quant_bits
+            os.environ[self.RBLN_QUANT_BITS_ENV] = "4"
 
     def maybe_reset_quantization_env(self):
         if self.RBLN_QUANT_BITS_ENV in os.environ:
             os.environ.pop(self.RBLN_QUANT_BITS_ENV)
 
 
-…
+class QuantizedLayerFactory:
+    def __init__(self, quantization_config: RBLNQuantizationConfig):
+        self.quantization_config = quantization_config
+
+    def create_linear(self, layer: Linear) -> Linear:
+        if self.quantization_config.weights in ["int4", "int8"]:
+            return self.create_qlinear(layer)
+        elif self.quantization_config.weights == "fp8":
+            return self.create_fp8linear(layer)
+        else:
+            raise ValueError(f"Invalid quantization weights: {self.quantization_config.weights}")
+
+    def create_qlinear(self, layer: Linear) -> Linear:
+        return create_qlinear(layer, self.quantization_config)
+
+    def create_fp8linear(self, layer: Linear) -> Linear:
+        return create_fp8linear(layer, self.quantization_config)
 
 
 def prepare_model_for_quantization(
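Taken together, the additions above let one config object describe an fp8 checkpoint end to end. A brief illustrative sketch (class and method names come from the diff; the concrete values are assumptions):

    cfg = RBLNQuantizationConfig(weights="fp8", activations="fp8", kv_caches="fp8")

    # _validate() accepts this combination because "fp8" is now listed in
    # SUPPORTED_WEIGHTS, SUPPORTED_ACTIVATIONS and SUPPORTED_KVCACHES.
    cfg._prepare_for_serialization()
    # -> {"format": "rbln", "weights": "fp8", "activations": "fp8", "kv_caches": "fp8"}

    # The factory then routes each Linear layer to the matching converter:
    # int4/int8 -> create_qlinear, fp8 -> create_fp8linear.
    factory = QuantizedLayerFactory(cfg)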
@@ -111,64 +147,51 @@ def prepare_model_for_quantization(
     cache_dir: Optional[str] = None,
     force_download: bool = False,
     local_files_only: bool = False,
+    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
 ) -> torch.nn.Module:
     """
     Prepare the model for quantization by updating specified linear layers to quantized (qlinear) layers.
     """
-…
+
+    # 1. Load weight files
+    safetensor_files = load_weight_files(
         model_id,
-        n_layer,
         use_auth_token=use_auth_token,
         revision=revision,
         cache_dir=cache_dir,
         force_download=force_download,
         local_files_only=local_files_only,
     )
-    return model
-
 
-…
-    Updates specified linear layers to quantized (qlinear) layers in the given module.
-    """
-
-    logger.debug("Updating layers to be quantized")  # TODO(jongho): remove.
-    processed_layers = []
+    # 2. Update linear layers based on the quantization config
+    update_layers_to_quantize(model, rbln_quantization)
 
-…
+    # 3. Load weights into model parameters
+    load_weights_from_files(
+        model,
+        safetensor_files,
+        n_layer,
+        rbln_quantization=rbln_quantization,
+    )
 
-
-    logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
+    return model
 
 
-def load_weights(
-…
-    local_files_only=False,
-):
+def load_weight_files(
+    model_id: str,
+    use_auth_token: Optional[Union[bool, str]] = None,
+    revision: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+    force_download: bool = False,
+    local_files_only: bool = False,
+) -> list[str]:
     """
-…
+    Discover and download safetensors files for the given model id.
     """
 
-    model_params = dict(model.named_parameters(recurse=True))
-    model_buffers = dict(model.named_buffers(recurse=True))
-
     if os.path.isdir(model_id):
         safetensor_files = glob.glob(f"{model_id}/*.safetensors")
     else:
-        from huggingface_hub import hf_hub_download, list_repo_files
-
         try:
             # List all files in the repository
             repo_files = list_repo_files(model_id, revision=revision, token=use_auth_token)
@@ -195,27 +218,238 @@ def load_weights(
     if not safetensor_files:
         raise FileNotFoundError(f"No safetensors files found for model_id: {model_id}")
 
+    return safetensor_files
+
+
+def update_layers_to_quantize(
+    module: torch.nn.Module,
+    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
+) -> None:
+    """
+    Updates specified linear layers to quantized (qlinear) layers in the given module.
+    """
+
+    processed_layers = []
+    quantized_layer_factory = QuantizedLayerFactory(rbln_quantization)
+
+    for name, layer in module.named_modules():
+        if is_target_for_qlinear_replacement(name, layer):
+            parent_module, layer_name = get_parent_and_child(module, name)
+            setattr(parent_module, layer_name, quantized_layer_factory.create_linear(layer))
+            processed_layers.append(name)
+
+    if processed_layers:
+        logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
+
+
+def _last_segment(key: str) -> str:
+    parts = key.split(".")
+    return parts[-1]
+
+
+def _replace_last_with(key: str, new_tail: str) -> str:
+    parts = key.split(".")
+    return ".".join(parts[:-1] + new_tail.split("."))
+
+
+def _matches_any_alias(key: str, kind: str) -> bool:
+    tail = _last_segment(key)
+    return tail in VARIANT_ALIASES.get(kind, [])
+
+
+def _reduce_to_scalar(t: torch.Tensor) -> torch.Tensor:
+    if t.ndim == 0:
+        return t
+    return t.reshape(-1).amax()
+
+
+def _coerce_per_out_channel_scale(scale: torch.Tensor, out_features: int) -> torch.Tensor:
+    s = scale
+    if s.ndim == 0:
+        # scalar -> expand to [out_features, 1]
+        return s.reshape(1, 1).expand(out_features, 1).contiguous()
+    if s.ndim == 1:
+        if s.numel() == 1:
+            return s.reshape(1, 1).expand(out_features, 1).contiguous()
+        if s.numel() == out_features:
+            return s.reshape(out_features, 1).contiguous()
+        # fallback: reduce to scalar then expand
+        v = _reduce_to_scalar(s)
+        return v.reshape(1, 1).expand(out_features, 1).contiguous()
+    if s.ndim == 2:
+        if s.shape == (out_features, 1):
+            return s.contiguous()
+        if s.shape == (1, out_features):
+            return s.transpose(0, 1).contiguous()
+        # fallback: reduce to [out_features] on non-out dims if possible
+        if s.shape[0] == out_features:
+            v = s
+            while v.ndim > 2:
+                v = v.amax(dim=-1)
+            if v.shape[-1] != 1:
+                v = v.amax(dim=-1, keepdim=True)
+            return v.contiguous()
+        # otherwise reduce to scalar then expand
+        v = _reduce_to_scalar(s)
+        return v.reshape(1, 1).expand(out_features, 1).contiguous()
+    # high-rank: reduce to scalar then expand
+    v = _reduce_to_scalar(s)
+    return v.reshape(1, 1).expand(out_features, 1).contiguous()
+
+
+def _kv_split_items(base_key: str, tensor: torch.Tensor) -> List[Tuple[str, torch.Tensor]]:
+    # base_key is the original key whose last token was 'kv_scale'
+    # We produce keys with 'k_proj.k_scale' and 'v_proj.v_scale'
+    if tensor.ndim == 1 and tensor.numel() >= 2:
+        tk, tv = tensor[0], tensor[1]
+    elif tensor.ndim == 2 and tensor.shape[0] >= 2 and tensor.shape[1] == 1:
+        tk, tv = tensor[0, 0], tensor[1, 0]
+    else:
+        tk = tv = tensor
+    k_key = _replace_last_with(base_key, "k_proj.k_scale")
+    v_key = _replace_last_with(base_key, "v_proj.v_scale")
+    return [(k_key, tk), (v_key, tv)]
+
+
+def canonicalize_checkpoint_items(
+    model: torch.nn.Module,
+    items: Iterable[Tuple[str, torch.Tensor]],
+    rbln_quantization: Optional[RBLNQuantizationConfig],
+) -> List[Tuple[str, torch.Tensor]]:
+    params = dict(model.named_parameters(recurse=True))
+    results: List[Tuple[str, torch.Tensor]] = []
+
+    for key, value in items:
+        t = value
+        # Normalize weight scale variants
+        if _matches_any_alias(key, "weight_scale"):
+            # rename last token to the canonical weight scale key
+            target_key = _replace_last_with(key, "weight_scale")
+
+            # Determine associated weight param to infer shape
+            weight_key = _replace_last_with(target_key, "weight")
+            out_features = None
+            if weight_key in params:
+                wshape = params[weight_key].shape
+                if len(wshape) == 2:
+                    out_features = int(wshape[0])
+
+            if rbln_quantization.weights in ["int4", "int8"] and out_features is not None:
+                t = _coerce_per_out_channel_scale(t.to(torch.float32), out_features)
+            elif rbln_quantization.weights == "fp8":
+                # Use a conservative scalar scale to ensure broadcastability
+                t = _reduce_to_scalar(t.to(torch.float32))
+            else:
+                t = t.to(torch.float32)
+
+            results.append((target_key, t))
+            continue
+
+        # Normalize input/activation scale variants
+        if _matches_any_alias(key, "input_scale"):
+            target_key = _replace_last_with(key, "input_scale")
+            t = _reduce_to_scalar(t.to(torch.float32))
+            results.append((target_key, t))
+            continue
+
+        # KV scale handling
+        if _matches_any_alias(key, "kv_scale"):
+            # For quark-like formats, expand to k/v
+            kv_items = _kv_split_items(key, t.to(torch.float32))
+            for k2, v2 in kv_items:
+                results.append((k2, v2))
+            continue
+
+        if _matches_any_alias(key, "k_scale") or _matches_any_alias(key, "v_scale"):
+            results.append((key, t.to(torch.float32)))
+            continue
+
+        # Default: passthrough
+        results.append((key, t))
+
+    return results
+
+
+def load_weights_from_files(
+    model: torch.nn.Module,
+    safetensor_files: list[str],
+    n_layer: Optional[int] = None,
+    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
+):
+    """
+    Load safetensor file data directly into the model from provided safetensor files,
+    filtering by layer if n_layer is provided.
+    """
+
+    model_params = dict(model.named_parameters(recurse=True))
+    model_buffers = dict(model.named_buffers(recurse=True))
+
     target_layers = list(range(n_layer)) if n_layer is not None else None
 
     unloaded_keys = []
+    loaded_input_scale = False
+    loaded_kv_scale = False
+    loaded_weight_scale = False
+
     for safetensor_file in safetensor_files:
         file_data = load_file(safetensor_file)
-…
+
+        # Normalize all (key, tensor) pairs to the internal schema
+        normalized_items = canonicalize_checkpoint_items(
+            model=model,
+            items=file_data.items(),
+            rbln_quantization=rbln_quantization,
+        )
+
+        for key, value in normalized_items:
+            # Track which types of scales were observed (post-normalization)
+            if key.endswith("input_scale"):
+                loaded_input_scale = True
+            if key.endswith("weight_scale"):
+                loaded_weight_scale = True
+            if key.endswith("k_scale") or key.endswith("v_scale"):
+                loaded_kv_scale = True
+
+            # Filter by layer index if requested
             if target_layers is not None:
                 parts = key.split(".")
-
                 if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
                     continue
 
+            # Copy into parameters or buffers
             if key in model_params:
+                # Ensure dtype compatibility
+                if model_params[key].dtype != value.dtype:
+                    value = value.to(model_params[key].dtype)
                 model_params[key].data.copy_(value)
             elif key in model_buffers:
+                if model_buffers[key].dtype != value.dtype:
+                    value = value.to(model_buffers[key].dtype)
                 model_buffers[key].data.copy_(value)
             else:
                 unloaded_keys.append(key)
 
     if len(unloaded_keys) > 0:
         logger.warning(f"There are unexpected parameters/buffers on the checkpoint: {unloaded_keys}")
+    if not loaded_input_scale and rbln_quantization.activations == "fp8":
+        raise ValueError(
+            "No input_scale found in the checkpoint. Did you use the correct quantization config? "
+            "If you are using fp8 quantization, you need to use the correct quantization config."
+        )
+    if not loaded_weight_scale and rbln_quantization.weights == "fp8":
+        raise ValueError(
+            "No weight_scale found in the checkpoint. Did you use the correct quantization config? "
+            "If you are using fp8 quantization, you need to use the correct quantization config."
+        )
+    if not loaded_kv_scale and rbln_quantization.kv_caches == "fp8":
+        raise ValueError(
+            "No kv_scale found in the checkpoint. Did you use the correct quantization config? "
+            "If you are using fp8 quantization, you need to use the correct quantization config."
+        )
+    if loaded_kv_scale and rbln_quantization.kv_caches != "fp8":
+        logger.warning(
+            "kv_scale found in the checkpoint, but kv_caches of quantization config is not fp8. Ignoring kv_scale."
+        )
 
 
 def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
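To make the alias handling above concrete, a small illustrative walk-through; the checkpoint key is invented, while the helper behaviour is the one defined in the diff:

    import torch

    # A community checkpoint may store a per-channel weight scale under "scales".
    key = "model.layers.0.self_attn.q_proj.scales"

    # "scales" appears in VARIANT_ALIASES["weight_scale"], so canonicalize_checkpoint_items
    # renames the key:
    #   _replace_last_with(key, "weight_scale")
    #   -> "model.layers.0.self_attn.q_proj.weight_scale"

    # For int4/int8 weights the scale tensor is coerced to shape [out_features, 1]:
    scale = torch.rand(4096)                       # e.g. one scale per output channel
    coerced = scale.reshape(4096, 1).contiguous()  # what _coerce_per_out_channel_scale returns here

    # A fused "kv_scale" entry is split into two canonical keys:
    #   "...self_attn.kv_scale" -> "...self_attn.k_proj.k_scale", "...self_attn.v_proj.v_scale"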
@@ -225,6 +459,10 @@ def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -
     return layer_name.split(".")[-1] in QUANTIZED_WEIGHTS and isinstance(layer, torch.nn.Linear)
 
 
+def is_target_for_adding_kv_scales(layer_name: str) -> bool:
+    return layer_name.split(".")[-1] in ["self_attn"]
+
+
 def get_parent_and_child(module: torch.nn.Module, full_name: str) -> tuple:
     """
     Splits the full layer name to retrieve the parent module and the child layer.
@@ -243,22 +481,84 @@ def access_attribute(obj: Any, attributes: list[str]) -> Any:
     return obj
 
 
-def create_qlinear(layer: Linear) -> Linear:
+def create_qlinear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
     """
     Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass.
     """
 
     def qlinear_forward(self, inputs: torch.Tensor) -> torch.Tensor:
-…
+        weight_scale = self.weight_scale
+        if inputs.dtype != weight_scale.dtype:
+            raise TypeError(f"Expected input dtype {weight_scale.dtype}, but got {inputs.dtype}")
 
         w_fp = self.weight.type(inputs.dtype)
-        w_fp *= …
+        w_fp *= weight_scale.view(-1, 1)
         return F.linear(inputs, w_fp, self.bias)
 
     # Convert weight to int8 and add scale parameter
     layer.weight = Parameter(layer.weight.to(torch.int8), requires_grad=False)
-    layer.…
+    layer.weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=torch.float32), requires_grad=False)
     layer.forward = lambda inputs: qlinear_forward(layer, inputs)
 
     return layer
+
+
+def create_fp8linear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
+    """
+    Converts a standard linear layer to a fp8 linear layer with a custom forward pass.
+    """
+
+    def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
+        finfo = torch.finfo(torch.float8_e4m3fn)
+        qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
+        return qweight
+
+    def fp8_gemm(A: torch.Tensor, A_scale, B: torch.Tensor, B_scale, bias, out_dtype: torch.dtype):
+        A = A.type(out_dtype)
+        B = B.type(out_dtype)
+
+        if A_scale is not None:
+            A *= A_scale
+        if B_scale is not None:
+            B *= B_scale.to(out_dtype)
+
+        output = torch.nn.functional.linear(A, B, bias=bias)
+        return output
+
+    def fp8linear_forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.input_scale:
+            input = static_per_tensor_quantize(x, self.input_scale)
+        else:
+            input = x
+
+        if self.weight_scale:
+            # broadcast weight_scale to vector
+            weight_scale = self.weight_scale.broadcast_to(self.weight.shape[-1:])
+        else:
+            weight_scale = None
+        output = fp8_gemm(
+            A=input,
+            A_scale=self.input_scale,
+            B=self.weight,
+            B_scale=weight_scale,
+            bias=self.bias,
+            out_dtype=x.dtype,
+        )
+
+        return output
+
+    layer.weight = Parameter(layer.weight.to(torch.float8_e4m3fn), requires_grad=False)
+    layer.weight_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+
+    if rbln_quantization.activations == "fp8":
+        layer.input_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+    else:
+        layer.input_scale = None
+
+    if rbln_quantization.kv_caches == "fp8":
+        layer.k_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+        layer.v_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+
+    layer.forward = lambda inputs: fp8linear_forward(layer, inputs)
+
+    return layer
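Both converters keep a dequantize-then-matmul forward rather than a fused low-precision kernel. Numerically, the two paths reduce to roughly the following standalone sketch (shapes and scale values are made up; only the formulas mirror the code above):

    import torch

    out_features, in_features = 8, 16
    x = torch.randn(2, in_features, dtype=torch.float32)

    # int4/int8 path: weights are stored as int8 plus a per-output-channel scale,
    # dequantized on the fly before a normal F.linear.
    w_int8 = torch.randint(-128, 127, (out_features, in_features), dtype=torch.int8)
    weight_scale = torch.rand(out_features, 1)
    y_int = torch.nn.functional.linear(x, w_int8.float() * weight_scale.view(-1, 1))

    # fp8 path: activations are statically quantized per tensor, clamped to the
    # float8_e4m3fn range, then both operands are rescaled back for the matmul.
    finfo = torch.finfo(torch.float8_e4m3fn)
    input_scale = torch.tensor(0.05)
    x_q = (x / input_scale).clamp(min=finfo.min, max=finfo.max)
    w_fp8 = torch.randn(out_features, in_features).to(torch.float8_e4m3fn)
    y_fp8 = torch.nn.functional.linear(x_q * input_scale, w_fp8.float())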
optimum/rbln/utils/runtime_utils.py
CHANGED

@@ -14,7 +14,7 @@
 
 import re
 import threading
-from typing import Any, …
+from typing import Any, List, Optional, Union
 
 import rebel
 import torch
@@ -94,7 +94,7 @@ class RBLNPytorchRuntime:
     def __call__(self, *args: Any, **kwds: Any) -> Any:
         return self.forward(*args, **kwds)
 
-    def forward(self, *args: List["torch.Tensor"], **kwargs: …
+    def forward(self, *args: List["torch.Tensor"], **kwargs: "torch.Tensor"):
         # filtering useless args or kwarg such as None.
         args = list(filter(lambda arg: isinstance(arg, torch.Tensor), args))
         kwargs = dict(filter(lambda kwarg: isinstance(kwarg[1], torch.Tensor) or kwarg[0] == "out", kwargs.items()))
@@ -142,7 +142,7 @@ class UnavailableRuntime:
         """Returns an iterator with self as the only item."""
         return iter([self])
 
-    def forward(self, *args: List["torch.Tensor"], **kwargs: …
+    def forward(self, *args: List["torch.Tensor"], **kwargs: "torch.Tensor"):
         """Raises a detailed RuntimeError explaining why inference cannot be performed."""
         raise RuntimeError(
             "Cannot perform inference: RBLN runtime is not available.\n\n"
optimum/rbln/utils/submodule.py
CHANGED

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
 
 from transformers import PretrainedConfig
 
@@ -22,7 +22,7 @@ from ..utils.model_utils import get_rbln_model_cls
 
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
 
     from ..modeling import RBLNModel
 
@@ -42,7 +42,12 @@ class SubModulesMixin:
             setattr(self, submodule_meta["name"], submodule)
 
     @classmethod
-    def _update_submodule_config(…
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
         return rbln_config
 
     @classmethod
@@ -51,6 +56,7 @@ class SubModulesMixin:
     ) -> List["RBLNModel"]:
         rbln_submodules = []
         submodule_prefix = getattr(cls, "_rbln_submodule_prefix", None)
+        preprocessors = kwargs.pop("preprocessors", [])
 
         for submodule in cls._rbln_submodules:
             submodule_name = submodule["name"]
@@ -69,7 +75,7 @@ class SubModulesMixin:
                 submodule_rbln_config = submodule_rbln_config_class(**submodule_rbln_config)
                 setattr(rbln_config, submodule_name, submodule_rbln_config)
 
-            submodule_rbln_config = submodule_cls._update_submodule_config(model, submodule_rbln_config)
+            submodule_rbln_config = submodule_cls._update_submodule_config(model, submodule_rbln_config, preprocessors)
 
             rbln_submodule = submodule_cls.from_model(
                 model=torch_submodule,
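Downstream, a submodule class can now consult the preprocessors when refining its config. A hypothetical override sketch (the class name, the `image_size` field, and the processor attribute are invented; only the hook signature and the `RBLNModel` import path come from this diff):

    from optimum.rbln.modeling import RBLNModel

    class RBLNMyVisionSubmodule(RBLNModel):
        @classmethod
        def _update_submodule_config(cls, model, rbln_config, preprocessors):
            # e.g. pick an image size off whichever processor SubModulesMixin forwarded
            for preprocessor in preprocessors or []:
                size = getattr(preprocessor, "size", None)
                if size is not None:
                    rbln_config.image_size = size
                    break
            return rbln_config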
{optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.8.2a7
+Version: 0.8.3
 Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai