angelslim 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- angelslim/__init__.py +15 -0
- angelslim/compressor/__init__.py +15 -0
- angelslim/compressor/compressor_factory.py +83 -0
- angelslim/compressor/distill/__init__.py +13 -0
- angelslim/compressor/quant/__init__.py +18 -0
- angelslim/compressor/quant/core/__init__.py +25 -0
- angelslim/compressor/quant/core/config.py +166 -0
- angelslim/compressor/quant/core/dit_hook.py +92 -0
- angelslim/compressor/quant/core/hook.py +184 -0
- angelslim/compressor/quant/core/metrics.py +88 -0
- angelslim/compressor/quant/core/packing_utils.py +117 -0
- angelslim/compressor/quant/core/quant_func.py +316 -0
- angelslim/compressor/quant/core/sample_func.py +45 -0
- angelslim/compressor/quant/core/save.py +211 -0
- angelslim/compressor/quant/modules/__init__.py +25 -0
- angelslim/compressor/quant/modules/awq/__init__.py +13 -0
- angelslim/compressor/quant/modules/awq/auto_clip.py +290 -0
- angelslim/compressor/quant/modules/awq/auto_scale.py +222 -0
- angelslim/compressor/quant/modules/awq/awq.py +334 -0
- angelslim/compressor/quant/modules/awq/search.py +137 -0
- angelslim/compressor/quant/modules/catcher.py +33 -0
- angelslim/compressor/quant/modules/fp8/__init__.py +13 -0
- angelslim/compressor/quant/modules/fp8/fp8.py +142 -0
- angelslim/compressor/quant/modules/gptq/__init__.py +13 -0
- angelslim/compressor/quant/modules/gptq/gptq.py +319 -0
- angelslim/compressor/quant/modules/gptq/gptq_module.py +219 -0
- angelslim/compressor/quant/modules/helper_layer.py +628 -0
- angelslim/compressor/quant/modules/int8/__init__.py +13 -0
- angelslim/compressor/quant/modules/int8/int8.py +142 -0
- angelslim/compressor/quant/modules/smooth/__init__.py +13 -0
- angelslim/compressor/quant/modules/smooth/smooth.py +88 -0
- angelslim/compressor/quant/observers/__init__.py +23 -0
- angelslim/compressor/quant/observers/abs_max_activation.py +216 -0
- angelslim/compressor/quant/observers/abs_max_weight.py +77 -0
- angelslim/compressor/quant/observers/base_observer.py +115 -0
- angelslim/compressor/quant/observers/ema_activation.py +63 -0
- angelslim/compressor/quant/observers/groupwise_weight.py +85 -0
- angelslim/compressor/quant/observers/hist_activation.py +244 -0
- angelslim/compressor/quant/observers/observer.py +70 -0
- angelslim/compressor/quant/ptq.py +188 -0
- angelslim/compressor/sparsity/__init__.py +13 -0
- angelslim/compressor/speculative_decoding/__init__.py +13 -0
- angelslim/data/__init__.py +9 -0
- angelslim/data/base_dataset.py +54 -0
- angelslim/data/dataloader.py +95 -0
- angelslim/data/multimodal_dataset.py +153 -0
- angelslim/data/text_dataset.py +162 -0
- angelslim/engine.py +183 -0
- angelslim/models/__init__.py +17 -0
- angelslim/models/base_model.py +309 -0
- angelslim/models/llm/__init__.py +19 -0
- angelslim/models/llm/deepseek.py +79 -0
- angelslim/models/llm/hunyuan_dense.py +71 -0
- angelslim/models/llm/hunyuan_moe.py +94 -0
- angelslim/models/llm/llama.py +68 -0
- angelslim/models/llm/qwen.py +92 -0
- angelslim/models/model_factory.py +69 -0
- angelslim/models/vlm/__init__.py +16 -0
- angelslim/models/vlm/qwen_vl.py +78 -0
- angelslim/utils/__init__.py +25 -0
- angelslim/utils/config_parser.py +389 -0
- angelslim/utils/default_compress_config.py +128 -0
- angelslim/utils/utils.py +134 -0
- angelslim-0.1.0.dist-info/METADATA +45 -0
- angelslim-0.1.0.dist-info/RECORD +68 -0
- angelslim-0.1.0.dist-info/WHEEL +5 -0
- angelslim-0.1.0.dist-info/licenses/LICENSE +87 -0
- angelslim-0.1.0.dist-info/top_level.txt +1 -0
angelslim/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2025 Tencent Inc. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .engine import Engine # noqa: F401
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2025 Tencent Inc. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .compressor_factory import CompressorFactory # noqa: F401
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
# Copyright 2025 Tencent Inc. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any, Callable, Dict, Optional, Type, Union
|
|
16
|
+
|
|
17
|
+
from ..utils import print_info
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class CompressorFactory:
    """
    Registry-backed factory for model compression methods.

    Compressor classes are kept in a class-level mapping, keyed either by an
    explicitly supplied name or by the class's own ``__name__``, and are
    instantiated on demand through :meth:`create`.
    """

    _compress_methods: Dict[str, Type[Any]] = {}

    @classmethod
    def register(cls, name: Optional[Union[str, Callable]] = None) -> Callable:
        """Register a compression method.

        Accepted forms:
            1. ``@CompressorFactory.register("explicit_name")`` - keyed by the
               given string.
            2. ``@CompressorFactory.register()`` - keyed by the class name.
            3. ``@CompressorFactory.register`` - keyed by the class name; the
               class itself arrives as ``name``.

        An existing entry with the same key is overwritten after a notice is
        printed.

        Raises:
            TypeError: if ``name`` is neither ``None``, a string nor a callable.
        """

        def store_under_own_name(target: Type[Any]) -> Type[Any]:
            """Store ``target`` keyed by its ``__name__`` and hand it back."""
            methods = cls._compress_methods
            label = target.__name__
            if label in methods:
                print_info(
                    f"Compression method '{label}' already exists, will be overwritten."
                )
            methods[label] = target
            return target

        def store_under(label: str) -> Callable[[Type[Any]], Type[Any]]:
            """Build a decorator that stores its argument keyed by ``label``."""

            def bind(target: Type[Any]) -> Type[Any]:
                methods = cls._compress_methods
                if label in methods:
                    print_info(
                        f"register '{label}' already exists, will be overwritten."
                    )
                methods[label] = target
                return target

            return bind

        # Dispatch on what the caller handed over.
        if name is None:
            # register() with parentheses and no argument.
            return store_under_own_name
        if isinstance(name, str):
            # register("name").
            return store_under(name)
        if callable(name):
            # Bare @register: the decorated class arrives as ``name``.
            return store_under_own_name(name)
        raise TypeError("Invalid argument type for registration")

    @classmethod
    def create(cls, name: str, model: Any, slim_config: Any) -> Any:
        """Instantiate the compressor registered under ``name``.

        The registered callable is invoked as ``(model, slim_config)``.

        Raises:
            ValueError: if nothing is registered under ``name``.
        """
        if name in cls._compress_methods:
            return cls._compress_methods[name](model, slim_config)
        available = list(cls._compress_methods.keys())
        raise ValueError(
            f"Compression method '{name}' not registered. Available: {available}"
        )

    @classmethod
    def get_available_compressor(cls) -> list:
        """Return the registered keys, in registration order."""
        return list(cls._compress_methods)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Tencent Inc. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# Copyright 2025 Tencent Inc. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .core import * # noqa: F401 F403
|
|
16
|
+
from .modules import * # noqa: F401 F403
|
|
17
|
+
from .observers import * # noqa: F401 F403
|
|
18
|
+
from .ptq import PTQ # noqa: F401
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# Copyright 2025 Tencent Inc. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .config import * # noqa: F401 F403
|
|
16
|
+
from .hook import DiTHook, PTQHook # noqa: F401
|
|
17
|
+
from .metrics import mse_loss, snr_loss # noqa: F401
|
|
18
|
+
from .packing_utils import dequantize_gemm # noqa: F401
|
|
19
|
+
from .quant_func import * # noqa: F401 F403
|
|
20
|
+
from .sample_func import EMASampler, MultiStepSampler # noqa: F401
|
|
21
|
+
from .save import PTQPTMSave # noqa: F401
|
|
22
|
+
from .save import PTQSaveVllmHF # noqa: F401
|
|
23
|
+
from .save import PTQTorchSave # noqa: F401
|
|
24
|
+
from .save import PTQvLLMSaveHF # noqa: F401
|
|
25
|
+
from .save import PTQVLMSaveVllmHF # noqa: F401
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
# Copyright 2025 Tencent Inc. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import List
|
|
16
|
+
|
|
17
|
+
from ..observers import (
|
|
18
|
+
AbsMaxChannelWiseWeightObserver,
|
|
19
|
+
AbsMaxGroupWiseWeightObserver,
|
|
20
|
+
AbsmaxPerchannelObserver,
|
|
21
|
+
AbsmaxPertensorObserver,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# Lookup tables from a quant-method name to the observer class used for it.
# Activation observers, keyed by the "activation" quant-method name.
ACT_OBSERVERS_CLASS = {
    "per-tensor": AbsmaxPertensorObserver,
    "per-channel": AbsmaxPerchannelObserver,
}
# Weight observers, keyed by the "weight" quant-method name.
WEIGHT_OBSERVERS_CLASS = {
    "per-tensor": AbsmaxPertensorObserver,
    "per-channel": AbsMaxChannelWiseWeightObserver,
    "per-group": AbsMaxGroupWiseWeightObserver,
}

# KV-cache observers; only "per-channel" is registered.
KVCACHE_OBSERVERS_CLASS = {"per-channel": AbsmaxPerchannelObserver}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class QuantConfig:
    r"""
    Configure how to quantize a model or a part of the model. It will map each layer to
    an instance of observers by the settings.

    Args:
        config: The quant config. ``config.quantization`` supplies the
            algorithm name, bit width, quant methods and quant helpers.
        global_config: Global settings. ``max_seq_length`` is always read, and
            ``hidden_size`` / ``model_arch_type`` are read by the algorithm
            branches, so despite the ``None`` default a real object is needed.
    Examples:
    .. code-block:: python

        >>> from angelslim.compressor.quant import QuantConfig
        >>> q_config = QuantConfig(config, global_config)
    """

    def __init__(self, config, global_config=None):
        # Observers stay None unless the selected algorithm assigns them.
        self.act_observer = None
        self.weight_observer = None
        self.kv_cache_observer = None

        quantization_args = config.quantization
        self.quant_algo = quantization_args.name
        self.quant_bit = quantization_args.bits
        self.max_seq_length = global_config.max_seq_length
        self.quant_helpers = quantization_args.quant_helpers
        act_quant_method = quantization_args.quant_method.get("activation", None)
        weight_quant_method = quantization_args.quant_method["weight"]

        # Branch order matters: the checks are substring tests on the name.
        if "fp8" in self.quant_algo:
            self._init_observer_based(
                "fp8",
                quantization_args,
                global_config,
                act_quant_method,
                weight_quant_method,
            )
        elif "int8" in self.quant_algo:
            self._init_observer_based(
                "int8",
                quantization_args,
                global_config,
                act_quant_method,
                weight_quant_method,
            )
        elif "int4_awq" in self.quant_algo:
            self.act_observer = None
            self.weight_observer = None
            self.kv_cache_observer = None
            group_size = self._resolve_group_size(quantization_args.quant_method)
            self.quant_algo_info = {
                "zero_point": quantization_args.quant_method["zero_point"],
                "group_size": int(group_size),
                "mse_range": quantization_args.quant_method["mse_range"],
            }
            self.hidden_size = global_config.hidden_size
            self.model_arch_type = global_config.model_arch_type
        elif "int4_gptq" in self.quant_algo:
            self.act_observer = None
            self.weight_observer = None
            self.kv_cache_observer = None
            # NOTE(review): unlike the AWQ branch, group_size is not cast to
            # int here and model_arch_type is not set - confirm intended.
            group_size = self._resolve_group_size(quantization_args.quant_method)
            self.quant_algo_info = {
                "group_size": group_size,
                "ignore_layers": quantization_args.ignore_layers,
            }
            self.hidden_size = global_config.hidden_size

        if "smooth" in self.quant_helpers:
            self.smooth_alpha = quantization_args.smooth_alpha
            self.smooth_observer = ACT_OBSERVERS_CLASS["per-channel"]
        self.custom_observe_layers_names = "default"

    @staticmethod
    def _resolve_group_size(quant_method):
        """Return the configured ``group_size``, mapping ``-1`` to 128."""
        configured = quant_method["group_size"]
        return 128 if configured == -1 else configured

    def _init_observer_based(
        self,
        prefix,
        quantization_args,
        global_config,
        act_quant_method,
        weight_quant_method,
    ):
        """Shared setup for the fp8 / int8 algorithms (``prefix`` names which).

        Raises:
            ValueError: for a static algorithm without an activation quant
                method. Previously the guarding ``assert`` tested a non-empty
                string, never fired, and the error surfaced as ``KeyError``.
        """
        is_dynamic = "dynamic" if "dynamic" in self.quant_algo else "static"
        if is_dynamic == "static" and act_quant_method is None:
            raise ValueError(
                f"[OpenSlim][Error] {prefix}_static need act_quant_method"
            )
        # Only static quantization needs an activation observer.
        self.act_observer = (
            ACT_OBSERVERS_CLASS[act_quant_method] if is_dynamic == "static" else None
        )
        self.weight_observer = WEIGHT_OBSERVERS_CLASS[weight_quant_method]
        self.kv_cache_observer = None
        self.quant_algo_info = {
            "w": f"{prefix}_{weight_quant_method}",
            "ignore_layers": quantization_args.ignore_layers,
        }
        if act_quant_method is not None:
            self.quant_algo_info["a"] = f"{prefix}_{act_quant_method}-{is_dynamic}"
        self.hidden_size = global_config.hidden_size
        self.model_arch_type = global_config.model_arch_type
        self.low_memory = quantization_args.low_memory

    def custom_observe_layers(
        self,
        names: List,
        act_observer="default",
        weight_observer="default",
        kv_cache_observer="default",
    ):
        """
        name supports fuzzy search.

        Args:
            names: layer names to observe; stored as given.
            act_observer: key of ``ACT_OBSERVERS_CLASS``; any other value
                (including the ``"default"`` sentinel) keeps the current one.
            weight_observer: key of ``WEIGHT_OBSERVERS_CLASS``; same rule.
            kv_cache_observer: key of ``KVCACHE_OBSERVERS_CLASS``; same rule.
        """
        self.custom_observe_layers_names = names
        # Store the observer class, as __init__ does, rather than the key
        # string that was stored here before.
        if act_observer in ACT_OBSERVERS_CLASS:
            self.act_observer = ACT_OBSERVERS_CLASS[act_observer]
        if weight_observer in WEIGHT_OBSERVERS_CLASS:
            self.weight_observer = WEIGHT_OBSERVERS_CLASS[weight_observer]
        if kv_cache_observer in KVCACHE_OBSERVERS_CLASS:
            self.kv_cache_observer = KVCACHE_OBSERVERS_CLASS[kv_cache_observer]
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
# Copyright 2025 Tencent Inc. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import re
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
import torch.nn as nn
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Compiled once at import time instead of on every filter_func call.
_EXCLUDED_NAME_PATTERN = re.compile(
    r".*(mlp_t5|pooler|style_embedder|x_embedder|t_embedder|extra_embedder).*"
)


def filter_func(name):
    """Return True when ``name`` contains one of the excluded keywords.

    The keywords are ``mlp_t5``, ``pooler``, ``style_embedder``,
    ``x_embedder``, ``t_embedder`` and ``extra_embedder``.
    """
    return _EXCLUDED_NAME_PATTERN.match(name) is not None
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class DiTHook:
    """Accumulate per-layer input/output activations of ``nn.Linear`` modules.

    A forward hook is attached to every ``nn.Linear`` sub-module whose name is
    not rejected by ``filter_func``. Each call adds the layer's first input and
    first output (detached, moved to CPU) to a running sum per layer name and
    increments a matching call counter.
    """

    def __init__(self, model, use_transformer_engine=False):
        """
        Args:
            model(nn.Module, required): the model to be quant
            use_transformer_engine(bool, optional): stored on the instance;
                this class does not otherwise read it.
        """
        self.model = model
        self.input_activation = {}
        self.output_activation = {}
        self.input_activation_cnt = {}
        self.output_activation_cnt = {}
        self.use_transformer_engine = use_transformer_engine
        self._apply_hook()

    def _apply_hook(self):
        """Attach the forward hook to every eligible ``nn.Linear`` module."""
        self._forward_hook_list = []
        # layer -> name, recorded once so the hook need not rescan the model.
        self._layer_names = {}
        for name, sub_layer in self.model.named_modules():
            if filter_func(name):
                continue
            if isinstance(sub_layer, nn.Linear):
                self._layer_names[sub_layer] = name
                handle = sub_layer.register_forward_hook(self._forward_pre_hook)
                self._forward_hook_list.append(handle)

    def _resolve_layer_name(self, layer):
        """Return the recorded name of ``layer``, scanning the model if unknown."""
        known = getattr(self, "_layer_names", None)
        if known is not None and layer in known:
            return known[layer]
        # Fallback for layers not recorded by _apply_hook.
        for name, module in self.model.named_modules():
            if filter_func(name):
                continue
            if module == layer:
                return name
        return ""

    @staticmethod
    def _first_on_cpu(value):
        """Return the first element of a tuple (or the value), detached on CPU."""
        tensor = value[0] if isinstance(value, tuple) else value
        return tensor.detach().cpu()

    @staticmethod
    def _accumulate(sums, counts, key, tensor):
        """Add ``tensor`` to ``sums[key]`` and bump ``counts[key]``."""
        if key in sums:
            sums[key] = sums[key] + tensor
        else:
            # Clone so the stored sum never shares storage with the live
            # activation (detach()/cpu() may return a view of it).
            sums[key] = tensor.clone()
        counts[key] = counts.get(key, 0) + 1

    def _forward_pre_hook(self, layer, input, output):
        """Forward hook (registered via ``register_forward_hook``)."""
        layer_name = self._resolve_layer_name(layer)
        self._accumulate(
            self.output_activation,
            self.output_activation_cnt,
            layer_name,
            self._first_on_cpu(output),
        )
        self._accumulate(
            self.input_activation,
            self.input_activation_cnt,
            layer_name,
            self._first_on_cpu(input),
        )

    def remove_hook(self):
        """Detach every hook registered by this instance."""
        for hook in self._forward_hook_list:
            hook.remove()
        self._forward_hook_list = []

    def clean_acitvation_list(self):
        """Drop all accumulated activations together with their counters."""
        self.input_activation = {}
        self.output_activation = {}
        # Reset the counters too, so they match the now-empty sums.
        self.input_activation_cnt = {}
        self.output_activation_cnt = {}
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
# Copyright 2025 Tencent Inc. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import re
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
from ..observers import ParentObserver, PTQObserver
|
|
20
|
+
from .quant_func import get_fp_maxval, get_fp_search_maxval
|
|
21
|
+
|
|
22
|
+
__all__ = ["PTQHook", "DiTHook"]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class PTQHook:
    """Attach PTQ observers to a quant model's layers through forward hooks."""

    def __init__(self, model):
        """Keep a reference to ``model`` and start with no hooks or observers."""
        self.quant_model = model
        self._forward_hook_list = []
        # {name: layer}
        self.quant_layers_dict = {}
        # {layer: observer}
        self.observer_dict = {}
        self.kv_names = []

    def apply_hook(self):
        """Create one ``PTQObserver`` per observed layer and hook it in."""
        model = self.quant_model
        self.quant_layers_dict = model.get_observer_layers()
        self.kv_names = model.get_kvcache_observer_layers_names(
            self.quant_layers_dict.keys()
        )
        act_observer = model.quant_algo_dict["act_observer"]
        weight_observer = model.quant_algo_dict["weight_observer"]
        kv_cache_observer = model.quant_algo_dict["kv_cache_observer"]

        # One ParentObserver instance per distinct parent value.
        quant_parent_dict = model.get_parent_dict(self.quant_layers_dict)
        parent_observers = {}
        for parent in set(quant_parent_dict.values()):
            parent_observers[parent] = ParentObserver()

        # apply observers
        for name, sub_layer in self.quant_layers_dict.items():
            extra_kwargs = {}
            if name in quant_parent_dict:
                extra_kwargs["parent_observer"] = parent_observers[
                    quant_parent_dict[name]
                ]
            # Only layers listed in kv_names receive the kv-cache observer.
            layer_kv_observer = kv_cache_observer if name in self.kv_names else None
            observer = PTQObserver(
                sub_layer,
                act_observer,
                weight_observer,
                layer_kv_observer,
                self.quant_model.quant_algo_dict,
                **extra_kwargs
            )
            handle = sub_layer.register_forward_hook(self._forward_hook)
            self.observer_dict[sub_layer] = observer
            self._forward_hook_list.append(handle)

    def apply_smooth_hook(self, smooth_mapping_layers, smooth_observer):
        """Hook a smooth-activation observer onto each smooth layer.

        Each value of ``smooth_mapping_layers`` is unpacked as a pair; only its
        first element (the smooth layer) is used here.
        """
        for smooth_layer, _ in smooth_mapping_layers.values():
            self.observer_dict[smooth_layer] = PTQObserver(
                smooth_layer,
                act_observer=None,
                weight_observer=None,
                kv_cache_observer=None,
                quant_algo_dict=self.quant_model.quant_algo_dict,
                smooth_act_observer=smooth_observer,
            )
            self._forward_hook_list.append(
                smooth_layer.register_forward_hook(self._forward_hook)
            )

    def _forward_hook(self, layer, input, output):
        """Feed cloned copies of the layer's first input/output to its observer."""
        first_input = input[0] if isinstance(input, tuple) else input
        first_output = output[0] if isinstance(output, tuple) else output
        x = first_input.clone()
        y = first_output.clone()
        model = self.quant_model
        if (
            hasattr(model, "apply_layer_norm_list")
            and layer in model.apply_layer_norm_list
        ):
            x = model.apply_layer_norm(layer, x)
        self.observer_dict[layer](x, y)
        return output

    def remove_hook(self):
        """Detach every hook registered by this instance."""
        for handle in self._forward_hook_list:
            handle.remove()
        self._forward_hook_list = []

    @staticmethod
    def _divide_scales(scales, maxval):
        """Divide every entry of ``scales`` in place by ``maxval`` (dtype-cast)."""
        for key, scale in scales.items():
            scales[key] = scale / maxval.type(scale.dtype)

    def post_process(self):
        """Divide the collected scales by the fp8 max value where fp8 is used."""
        maxval = get_fp_maxval(bits=8)
        model = self.quant_model

        if model.quant_algo_dict["w_quant_algo"] == "fp8":
            self._divide_scales(model.weight_scales_dict, maxval)

        if model.quant_algo_dict["a_quant_algo"] == "fp8":
            for name, sub_layer in self.quant_layers_dict.items():
                if sub_layer not in self.observer_dict:
                    continue
                if name not in model.act_scales_dict.keys():
                    continue
                act_dtype = model.act_scales_dict[name].dtype
                observer = self.observer_dict[sub_layer]
                # Observers whose str() contains "Search" use a max value
                # derived from their sampled input.
                if "Search" in str(observer):
                    divisor = get_fp_search_maxval(observer.sampled_input)
                else:
                    divisor = maxval
                model.act_scales_dict[name] = model.act_scales_dict[
                    name
                ] / divisor.type(act_dtype)

        if model.quant_algo_dict["c_quant_algo"] == "fp8":
            self._divide_scales(model.kv_cache_scales_dict, maxval)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _filter_func(name):
|
|
125
|
+
pattern = re.compile(
|
|
126
|
+
r".*(mlp_t5|pooler|style_embedder|x_embedder|t_embedder|extra_embedder).*"
|
|
127
|
+
)
|
|
128
|
+
return pattern.match(name) is not None
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class DiTHook:
    """Record input/output activations of Conv2d/Linear layers inside blocks.

    A forward hook is attached to every ``torch.nn.Conv2d`` / ``torch.nn.Linear``
    sub-module whose name contains ``"blocks"`` and is not rejected by
    ``_filter_func``. Each call appends ``(layer_name, tensor)`` pairs for the
    layer's first input and first output (detached, moved to CPU).
    """

    def __init__(self, model):
        """
        Args:
            model(nn.Module, required): the model to be quant
        """
        self.model = model
        self.input_activation = []
        self.output_activation = []

        self._apply_hook()

    def _apply_hook(self):
        """Attach the forward hook to every eligible layer."""
        self._forward_hook_list = []
        # layer -> name, recorded once so the hook need not rescan the model.
        self._layer_names = {}
        for name, sub_layer in self.model.named_modules():
            if _filter_func(name):
                continue
            if not isinstance(sub_layer, (torch.nn.Conv2d, torch.nn.Linear)):
                continue
            if "blocks" in name:
                self._layer_names[sub_layer] = name
                handle = sub_layer.register_forward_hook(self._forward_pre_hook)
                self._forward_hook_list.append(handle)

    def _resolve_layer_name(self, layer):
        """Return the recorded name of ``layer``, scanning the model if unknown."""
        known = getattr(self, "_layer_names", None)
        if known is not None and layer in known:
            return known[layer]
        # Fallback for layers not recorded by _apply_hook.
        for name, module in self.model.named_modules():
            if _filter_func(name):
                continue
            if module == layer:
                return name
        return ""

    def _forward_pre_hook(self, layer, input, output):
        """Forward hook (registered via ``register_forward_hook``)."""
        layer_name = self._resolve_layer_name(layer)
        first_output = output[0] if isinstance(output, tuple) else output
        self.output_activation.append((layer_name, first_output.detach().cpu()))
        first_input = input[0] if isinstance(input, tuple) else input
        self.input_activation.append((layer_name, first_input.detach().cpu()))

    def remove_hook(self):
        """Detach every hook registered by this instance."""
        for hook in self._forward_hook_list:
            hook.remove()
        self._forward_hook_list = []

    def clean_acitvation_list(self):
        """Drop all recorded activations."""
        self.input_activation = []
        self.output_activation = []
|