optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/utils/rbln_quantization.py

@@ -14,37 +14,67 @@

 import glob
 import os
-from typing import Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type, Union

 import torch
+from huggingface_hub import hf_hub_download, list_repo_files
 from safetensors.torch import load_file
 from torch.nn import Linear, Parameter
 from torch.nn import functional as F
+from transformers import AutoConfig
+from transformers.modeling_utils import get_state_dict_dtype, no_init_weights

 from ...configuration_utils import RBLNSerializableConfigProtocol
 from ...utils.logging import get_logger


+if TYPE_CHECKING:
+    from transformers.models.auto.modeling_auto import _BaseAutoModelClass
+
 logger = get_logger()


+# Constants
+QUANTIZED_WEIGHTS = {
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+}
+
+# Common alias sets seen in community checkpoints
+VARIANT_ALIASES: Dict[str, List[str]] = {
+    "weight_scale": ["weight_scale", "scales", "w_scale", "scale"],
+    "input_scale": ["input_scale", "act_scale", "activation_scale", "a_scale"],
+    "kv_scale": ["kv_scale", "kv_scales"],
+    "k_scale": ["k_scale", "k_scales"],
+    "v_scale": ["v_scale", "v_scales"],
+}
+
+
 class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
     SUPPORTED_FORMATS = ["rbln"]
-    SUPPORTED_WEIGHTS = ["int4", "fp16"]
-    SUPPORTED_ACTIVATIONS = ["fp16"]
-
-    # The RBLN_QUANT_BITS environment variable defines the precision of each layer during the graph compilation process.
-    # It specifies the quantization bit depth. For instance, setting RBLN_QUANT_BITS=4 will apply 4-bit precision for quantization.
+    SUPPORTED_WEIGHTS = ["int4", "int8", "fp8", "fp16"]
+    SUPPORTED_ACTIVATIONS = ["int8", "fp8", "fp16"]
+    SUPPORTED_KVCACHES = ["fp8", "fp16"]
     RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"

     def __init__(
         self,
         format: Optional[str] = None,
-        precision: Optional[str] = None,
         weights: Optional[str] = None,
         activations: Optional[str] = None,
+        kv_caches: Optional[str] = None,
+        *,
+        precision: Optional[str] = None,
     ):
-        self.format = format
+        self.format = format or "rbln"
+        if self.format not in self.SUPPORTED_FORMATS:
+            raise ValueError(f"Invalid format: {self.format}, supported formats are: {self.SUPPORTED_FORMATS}")
+
         if precision is not None:
             logger.warning("The `precision` argument is deprecated. Use `weights` and `activations` instead.")
             if any(precision_arg is not None for precision_arg in (weights, activations)):
@@ -58,6 +88,7 @@ class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):

         self.weights = weights or "fp16"
         self.activations = activations or "fp16"
+        self.kv_caches = kv_caches or "fp16"
         self._validate()

     def _validate(self):
@@ -69,106 +100,135 @@ class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
             raise ValueError(
                 f"Invalid activations: {self.activations}, supported activations are: {self.SUPPORTED_ACTIVATIONS}"
             )
+        if self.kv_caches not in self.SUPPORTED_KVCACHES:
+            raise ValueError(
+                f"Invalid kv_caches: {self.kv_caches}, supported kv_caches are: {self.SUPPORTED_KVCACHES}"
+            )
         if self.weights == "fp16" and self.activations == "fp16":
-            raise ValueError("weights and activations cannot be both fp16. It is meaningless.")
+            raise ValueError("weights and activations of QuantizationConfig cannot be both fp16. It is meaningless.")

     def _prepare_for_serialization(self) -> Dict[str, Any]:
         return {
             "format": self.format,
             "weights": self.weights,
             "activations": self.activations,
+            "kv_caches": self.kv_caches,
         }

     def maybe_set_quantization_env(self):
-        quant_bits = None
         if self.weights == "int4":
-
-            os.environ[self.RBLN_QUANT_BITS_ENV] = quant_bits
+            os.environ[self.RBLN_QUANT_BITS_ENV] = "4"

     def maybe_reset_quantization_env(self):
         if self.RBLN_QUANT_BITS_ENV in os.environ:
             os.environ.pop(self.RBLN_QUANT_BITS_ENV)

+    @property
+    def nbits_per_param(self) -> int:
+        if self.weights in ["int4", "fp4"]:
+            return 4
+        elif self.weights in ["int8", "fp8"]:
+            return 8
+        else:
+            raise ValueError(f"Invalid weights: {self.weights}")

-# Constants
-QUANTIZED_WEIGHTS = {
-    "q_proj",
-    "k_proj",
-    "v_proj",
-    "o_proj",
-    "gate_proj",
-    "up_proj",
-    "down_proj",
-}

+class QuantizedLayerFactory:
+    def __init__(self, quantization_config: RBLNQuantizationConfig):
+        self.quantization_config = quantization_config

-    def
-
+    def create_linear(self, layer: Linear) -> Linear:
+        if self.quantization_config.weights in ["int4", "int8"]:
+            return self.create_qlinear(layer)
+        elif self.quantization_config.weights == "fp8":
+            return self.create_fp8linear(layer)
+        else:
+            raise ValueError(f"Invalid quantization weights: {self.quantization_config.weights}")
+
+    def create_qlinear(self, layer: Linear) -> Linear:
+        return create_qlinear(layer, self.quantization_config)
+
+    def create_fp8linear(self, layer: Linear) -> Linear:
+        return create_fp8linear(layer, self.quantization_config)
+
+
+def get_quantized_model(
+    hf_auto_model_class: Type["_BaseAutoModelClass"],
     model_id: str,
-    n_layer: Optional[int] = None,
     use_auth_token: Optional[Union[bool, str]] = None,
     revision: Optional[str] = None,
     cache_dir: Optional[str] = None,
     force_download: bool = False,
     local_files_only: bool = False,
-
+    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
+    **kwargs,
+):
     """
-
+    Get a quantized model from a model class and model id.
     """
-
-
-
+    # torch_dtype should not be passed to AutoConfig.from_pretrained
+    # since it doesn't support 'auto'
+    torch_dtype = kwargs.pop("torch_dtype", None)
+    if torch_dtype is not None:
+        logger.warning(
+            "torch_dtype is not supported for quantized models. "
+            "It will be ignored and the dtype of the model will be determined by the weights."
+        )
+        torch_dtype = None
+
+    # get paths of safetensors files in the model repo
+    safetensor_files = load_weight_files(
         model_id,
-        n_layer,
         use_auth_token=use_auth_token,
         revision=revision,
         cache_dir=cache_dir,
         force_download=force_download,
         local_files_only=local_files_only,
     )
-    return model

+    # load safetensors files into memory
+    safetensors = [load_file(safetensor_file) for safetensor_file in safetensor_files]

-
-
-    Updates specified linear layers to quantized (qlinear) layers in the given module.
-    """
+    # get the dtype of the model from the first safetensor file
+    torch_dtype = get_state_dict_dtype(safetensors[0])

-
-
+    config = AutoConfig.from_pretrained(
+        model_id,
+        use_auth_token=use_auth_token,
+        revision=revision,
+        cache_dir=cache_dir,
+        force_download=force_download,
+        local_files_only=local_files_only,
+        **kwargs,
+    )

-
-
-            parent_module, layer_name = get_parent_and_child(module, name)
-            setattr(parent_module, layer_name, create_qlinear(layer))
-            processed_layers.append(name)
+    with no_init_weights():
+        model = hf_auto_model_class.from_config(config, torch_dtype=torch_dtype)

-
-
+    # Quantize the model
+    update_layers_to_quantize(model, rbln_quantization)

+    # Load weights into the model
+    load_weights_from_files(model, safetensors, rbln_quantization)

-
-
-
-
-
-
-
-
-
-
+    return model
+
+
+def load_weight_files(
+    model_id: str,
+    use_auth_token: Optional[Union[bool, str]] = None,
+    revision: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+    force_download: bool = False,
+    local_files_only: bool = False,
+) -> list[str]:
     """
-
+    Discover and download safetensors files for the given model id.
     """

-    model_params = dict(model.named_parameters(recurse=True))
-    model_buffers = dict(model.named_buffers(recurse=True))
-
     if os.path.isdir(model_id):
         safetensor_files = glob.glob(f"{model_id}/*.safetensors")
     else:
-        from huggingface_hub import hf_hub_download, list_repo_files
-
         try:
             # List all files in the repository
             repo_files = list_repo_files(model_id, revision=revision, token=use_auth_token)
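A minimal usage sketch of the configuration surface introduced above, assuming only the names shown in these hunks and the module path from the file list (illustrative, not an official example):

    from optimum.rbln.transformers.utils.rbln_quantization import (
        RBLNQuantizationConfig,
        get_quantized_model,
    )

    # weights / activations / kv_caches replace the deprecated `precision` argument;
    # an all-fp16 combination is rejected by _validate() as meaningless.
    cfg = RBLNQuantizationConfig(format="rbln", weights="fp8", activations="fp8", kv_caches="fp8")
    print(cfg.nbits_per_param)               # 8 (4 for int4/fp4 weights)
    print(cfg._prepare_for_serialization())  # now includes the "kv_caches" key

    # get_quantized_model() builds the skeleton under no_init_weights(), swaps target
    # Linear layers via QuantizedLayerFactory, then loads the safetensors weights.
    # from transformers import AutoModelForCausalLM
    # model = get_quantized_model(AutoModelForCausalLM, "org/fp8-checkpoint", rbln_quantization=cfg)  # hypothetical model id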
@@ -195,27 +255,226 @@ def load_weights(
     if not safetensor_files:
         raise FileNotFoundError(f"No safetensors files found for model_id: {model_id}")

-
+    return safetensor_files
+
+
+def update_layers_to_quantize(
+    module: torch.nn.Module,
+    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
+) -> None:
+    """
+    Updates specified linear layers to quantized (qlinear) layers in the given module.
+    """
+
+    processed_layers = []
+    quantized_layer_factory = QuantizedLayerFactory(rbln_quantization)
+
+    for name, layer in module.named_modules():
+        if is_target_for_qlinear_replacement(name, layer):
+            parent_module, layer_name = get_parent_and_child(module, name)
+            setattr(parent_module, layer_name, quantized_layer_factory.create_linear(layer))
+            processed_layers.append(name)
+
+    if processed_layers:
+        logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
+
+
+def _last_segment(key: str) -> str:
+    parts = key.split(".")
+    return parts[-1]
+
+
+def _replace_last_with(key: str, new_tail: str) -> str:
+    parts = key.split(".")
+    return ".".join(parts[:-1] + new_tail.split("."))
+
+
+def _matches_any_alias(key: str, kind: str) -> bool:
+    tail = _last_segment(key)
+    return tail in VARIANT_ALIASES.get(kind, [])
+
+
+def _reduce_to_scalar(t: torch.Tensor) -> torch.Tensor:
+    if t.ndim == 0:
+        return t
+    return t.reshape(-1).amax()
+
+
+def _coerce_per_out_channel_scale(scale: torch.Tensor, out_features: int) -> torch.Tensor:
+    s = scale
+    if s.ndim == 0:
+        # scalar -> expand to [out_features, 1]
+        return s.reshape(1, 1).expand(out_features, 1).contiguous()
+    if s.ndim == 1:
+        if s.numel() == 1:
+            return s.reshape(1, 1).expand(out_features, 1).contiguous()
+        if s.numel() == out_features:
+            return s.reshape(out_features, 1).contiguous()
+        # fallback: reduce to scalar then expand
+        v = _reduce_to_scalar(s)
+        return v.reshape(1, 1).expand(out_features, 1).contiguous()
+    if s.ndim == 2:
+        if s.shape == (out_features, 1):
+            return s.contiguous()
+        if s.shape == (1, out_features):
+            return s.transpose(0, 1).contiguous()
+        # fallback: reduce to [out_features] on non-out dims if possible
+        if s.shape[0] == out_features:
+            v = s
+            while v.ndim > 2:
+                v = v.amax(dim=-1)
+            if v.shape[-1] != 1:
+                v = v.amax(dim=-1, keepdim=True)
+            return v.contiguous()
+        # otherwise reduce to scalar then expand
+        v = _reduce_to_scalar(s)
+        return v.reshape(1, 1).expand(out_features, 1).contiguous()
+    # high-rank: reduce to scalar then expand
+    v = _reduce_to_scalar(s)
+    return v.reshape(1, 1).expand(out_features, 1).contiguous()
+
+
+def _kv_split_items(base_key: str, tensor: torch.Tensor) -> List[Tuple[str, torch.Tensor]]:
+    # base_key is the original key whose last token was 'kv_scale'
+    # We produce keys with 'k_proj.k_scale' and 'v_proj.v_scale'
+    if tensor.ndim == 1 and tensor.numel() >= 2:
+        tk, tv = tensor[0], tensor[1]
+    elif tensor.ndim == 2 and tensor.shape[0] >= 2 and tensor.shape[1] == 1:
+        tk, tv = tensor[0, 0], tensor[1, 0]
+    else:
+        tk = tv = tensor
+    k_key = _replace_last_with(base_key, "k_proj.k_scale")
+    v_key = _replace_last_with(base_key, "v_proj.v_scale")
+    return [(k_key, tk), (v_key, tv)]

-    unloaded_keys = []
-    for safetensor_file in safetensor_files:
-        file_data = load_file(safetensor_file)
-        for key, value in file_data.items():
-            if target_layers is not None:
-                parts = key.split(".")

-
-
+def canonicalize_checkpoint_items(
+    model: torch.nn.Module,
+    items: Iterable[Tuple[str, torch.Tensor]],
+    rbln_quantization: Optional[RBLNQuantizationConfig],
+) -> List[Tuple[str, torch.Tensor]]:
+    params = dict(model.named_parameters(recurse=True))
+    results: List[Tuple[str, torch.Tensor]] = []
+
+    for key, value in items:
+        t = value
+        # Normalize weight scale variants
+        if _matches_any_alias(key, "weight_scale"):
+            # rename last token to the canonical weight scale key
+            target_key = _replace_last_with(key, "weight_scale")
+
+            # Determine associated weight param to infer shape
+            weight_key = _replace_last_with(target_key, "weight")
+            out_features = None
+            if weight_key in params:
+                wshape = params[weight_key].shape
+                if len(wshape) == 2:
+                    out_features = int(wshape[0])
+
+            if rbln_quantization.weights in ["int4", "int8"] and out_features is not None:
+                t = _coerce_per_out_channel_scale(t.to(torch.float32), out_features)
+            elif rbln_quantization.weights == "fp8":
+                # Use a conservative scalar scale to ensure broadcastability
+                t = _reduce_to_scalar(t.to(torch.float32))
+            else:
+                t = t.to(torch.float32)
+
+            results.append((target_key, t))
+            continue
+
+        # Normalize input/activation scale variants
+        if _matches_any_alias(key, "input_scale"):
+            target_key = _replace_last_with(key, "input_scale")
+            t = _reduce_to_scalar(t.to(torch.float32))
+            results.append((target_key, t))
+            continue
+
+        # KV scale handling
+        if _matches_any_alias(key, "kv_scale"):
+            # For quark-like formats, expand to k/v
+            kv_items = _kv_split_items(key, t.to(torch.float32))
+            for k2, v2 in kv_items:
+                results.append((k2, v2))
+            continue
+
+        if _matches_any_alias(key, "k_scale") or _matches_any_alias(key, "v_scale"):
+            results.append((key, t.to(torch.float32)))
+            continue
+
+        # Default: passthrough
+        results.append((key, t))
+
+    return results
+
+
+def load_weights_from_files(
+    model: torch.nn.Module,
+    safetensors: List[Dict[str, torch.Tensor]],
+    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
+):
+    """
+    Load safetensor file data directly into the model from provided safetensor files.
+    """
+
+    model_params = dict(model.named_parameters(recurse=True))
+    model_buffers = dict(model.named_buffers(recurse=True))

+    unloaded_keys = []
+    loaded_input_scale = False
+    loaded_kv_scale = False
+    loaded_weight_scale = False
+
+    for safetensor in safetensors:
+        # Normalize all (key, tensor) pairs to the internal schema
+        normalized_items = canonicalize_checkpoint_items(
+            model=model,
+            items=safetensor.items(),
+            rbln_quantization=rbln_quantization,
+        )
+
+        for key, value in normalized_items:
+            # Track which types of scales were observed (post-normalization)
+            if key.endswith("input_scale"):
+                loaded_input_scale = True
+            if key.endswith("weight_scale"):
+                loaded_weight_scale = True
+            if key.endswith("k_scale") or key.endswith("v_scale"):
+                loaded_kv_scale = True
+
+            # Copy into parameters or buffers
             if key in model_params:
+                # Ensure dtype compatibility
+                if model_params[key].dtype != value.dtype:
+                    value = value.to(model_params[key].dtype)
                 model_params[key].data.copy_(value)
             elif key in model_buffers:
+                if model_buffers[key].dtype != value.dtype:
+                    value = value.to(model_buffers[key].dtype)
                 model_buffers[key].data.copy_(value)
             else:
                 unloaded_keys.append(key)

     if len(unloaded_keys) > 0:
         logger.warning(f"There are unexpected parameters/buffers on the checkpoint: {unloaded_keys}")
+    if not loaded_input_scale and rbln_quantization.activations == "fp8":
+        raise ValueError(
+            "No input_scale found in the checkpoint. Did you use the correct quantization config? "
+            "If you are using fp8 quantization, you need to use the correct quantization config."
+        )
+    if not loaded_weight_scale and rbln_quantization.weights == "fp8":
+        raise ValueError(
+            "No weight_scale found in the checkpoint. Did you use the correct quantization config? "
+            "If you are using fp8 quantization, you need to use the correct quantization config."
+        )
+    if not loaded_kv_scale and rbln_quantization.kv_caches == "fp8":
+        raise ValueError(
+            "No kv_scale found in the checkpoint. Did you use the correct quantization config? "
+            "If you are using fp8 quantization, you need to use the correct quantization config."
+        )
+    if loaded_kv_scale and rbln_quantization.kv_caches != "fp8":
+        logger.warning(
+            "kv_scale found in the checkpoint, but kv_caches of quantization config is not fp8. Ignoring kv_scale."
+        )


 def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
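To illustrate the checkpoint-key normalization added above, a small sketch using the helpers from this hunk (the keys are hypothetical community-checkpoint names, not taken from any real model):

    from optimum.rbln.transformers.utils.rbln_quantization import VARIANT_ALIASES, _replace_last_with

    key = "model.layers.0.self_attn.q_proj.scales"          # alias form of a weight scale
    assert key.split(".")[-1] in VARIANT_ALIASES["weight_scale"]
    print(_replace_last_with(key, "weight_scale"))
    # -> model.layers.0.self_attn.q_proj.weight_scale

    # a fused kv_scale entry is split into per-projection keys the same way
    print(_replace_last_with("model.layers.0.self_attn.kv_scale", "k_proj.k_scale"))
    # -> model.layers.0.self_attn.k_proj.k_scale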
@@ -225,6 +484,10 @@ def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -
     return layer_name.split(".")[-1] in QUANTIZED_WEIGHTS and isinstance(layer, torch.nn.Linear)


+def is_target_for_adding_kv_scales(layer_name: str) -> bool:
+    return layer_name.split(".")[-1] in ["self_attn"]
+
+
 def get_parent_and_child(module: torch.nn.Module, full_name: str) -> tuple:
     """
     Splits the full layer name to retrieve the parent module and the child layer.
@@ -243,22 +506,84 @@ def access_attribute(obj: Any, attributes: list[str]) -> Any:
     return obj


-def create_qlinear(layer: Linear) -> Linear:
+def create_qlinear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
     """
     Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass.
     """

     def qlinear_forward(self, inputs: torch.Tensor) -> torch.Tensor:
-
-
+        weight_scale = self.weight_scale
+        if inputs.dtype != weight_scale.dtype:
+            raise TypeError(f"Expected input dtype {weight_scale.dtype}, but got {inputs.dtype}")

         w_fp = self.weight.type(inputs.dtype)
-        w_fp *=
+        w_fp *= weight_scale.view(-1, 1)
         return F.linear(inputs, w_fp, self.bias)

     # Convert weight to int8 and add scale parameter
     layer.weight = Parameter(layer.weight.to(torch.int8), requires_grad=False)
-    layer.
+    layer.weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=torch.float32), requires_grad=False)
     layer.forward = lambda inputs: qlinear_forward(layer, inputs)

     return layer
+
+
+def create_fp8linear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
+    """
+    Converts a standard linear layer to a fp8 linear layer with a custom forward pass.
+    """
+
+    def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
+        finfo = torch.finfo(torch.float8_e4m3fn)
+        qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
+        return qweight
+
+    def fp8_gemm(A: torch.Tensor, A_scale, B: torch.Tensor, B_scale, bias, out_dtype: torch.dtype):
+        A = A.type(out_dtype)
+        B = B.type(out_dtype)
+
+        if A_scale is not None:
+            A *= A_scale
+        if B_scale is not None:
+            B *= B_scale.to(out_dtype)
+
+        output = torch.nn.functional.linear(A, B, bias=bias)
+        return output
+
+    def fp8linear_forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.input_scale:
+            input = static_per_tensor_quantize(x, self.input_scale)
+        else:
+            input = x
+
+        if self.weight_scale:
+            # broadcast weight_scale to vector
+            weight_scale = self.weight_scale.broadcast_to(self.weight.shape[-1:])
+        else:
+            weight_scale = None
+        output = fp8_gemm(
+            A=input,
+            A_scale=self.input_scale,
+            B=self.weight,
+            B_scale=weight_scale,
+            bias=self.bias,
+            out_dtype=x.dtype,
+        )
+
+        return output
+
+    layer.weight = Parameter(layer.weight.to(torch.float8_e4m3fn), requires_grad=False)
+    layer.weight_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+
+    if rbln_quantization.activations == "fp8":
+        layer.input_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+    else:
+        layer.input_scale = None
+
+    if rbln_quantization.kv_caches == "fp8":
+        layer.k_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+        layer.v_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+
+    layer.forward = lambda inputs: fp8linear_forward(layer, inputs)
+
+    return layer
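For reference, the per-tensor fp8 scheme used by create_fp8linear above can be reproduced standalone. A rough sketch in plain PyTorch (requires a build with float8 support; the calibration of `scale` here is an assumption for illustration, whereas the real scales come from the checkpoint's weight_scale/input_scale entries):

    import torch

    x = torch.randn(4, 8, dtype=torch.float16)
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = (x.abs().amax() / finfo.max).to(torch.float16)  # illustrative per-tensor scale

    # quantize: divide by the scale and clamp into the e4m3 representable range
    q = (x / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)

    # dequantize: upcast and multiply the scale back in, as fp8_gemm does before F.linear
    x_hat = q.to(torch.float16) * scale
    print((x - x_hat).abs().max())  # small round-trip error for a well-scaled tensor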
optimum/rbln/transformers/utils/rbln_runtime_wrapper.py (new file)

@@ -0,0 +1,79 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
+
+from torch.nn import Module
+
+from ...modeling import RBLNModel
+
+
+if TYPE_CHECKING:
+    import rebel
+
+
+class LoopProcessor(Module, ABC):
+    def __init__(self, model: Union[RBLNModel, "rebel.Runtime"]):
+        super().__init__()
+        self.model = model
+
+    def __repr__(self) -> str:
+        return repr(self.model)
+
+    def _is_batch_implemented(self) -> bool:
+        return self._forward_batch.__func__ is not LoopProcessor._forward_batch
+
+    def forward(self, *args, force_loop: bool = False, **kwargs) -> Any:
+        if not force_loop and self._is_batch_implemented():
+            return self._forward_batch(*args, **kwargs)
+        else:
+            return self._forward_loop(*args, **kwargs)
+
+    def _forward_loop(self, *args, **kwargs) -> Any:
+        batch_size = self._get_batch_size(*args, **kwargs)
+
+        if not isinstance(batch_size, int) or batch_size == 0:
+            return self._process_outputs([])
+
+        common_inputs = self._prepare_inputs_before_loop(*args, **kwargs)
+
+        outputs = []
+        for i in range(batch_size):
+            item_args, item_kwargs = self._prepare_inputs_for_iteration(i, common_inputs, *args, **kwargs)
+            item_output = self.model(*item_args, **item_kwargs)
+            outputs.append(item_output)
+
+        return self._process_outputs(outputs, **kwargs)
+
+    def _forward_batch(self, *args, **kwargs) -> Any:
+        raise NotImplementedError("The batch processing logic (_forward_batch) is not implemented in this class.")
+
+    @abstractmethod
+    def _get_batch_size(self, *args, **kwargs) -> int:
+        pass
+
+    @abstractmethod
+    def _prepare_inputs_for_iteration(
+        self, index: int, common_inputs: Dict[str, Any], *args, **kwargs
+    ) -> Tuple[List[Any], Dict[str, Any]]:
+        pass
+
+    def _prepare_inputs_before_loop(self, *args, **kwargs) -> Dict[str, Any]:
+        pass
+
+    @abstractmethod
+    def _process_outputs(self, outputs: List[Any], **kwargs) -> Any:
+        pass