angelslim 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. angelslim/__init__.py +15 -0
  2. angelslim/compressor/__init__.py +15 -0
  3. angelslim/compressor/compressor_factory.py +83 -0
  4. angelslim/compressor/distill/__init__.py +13 -0
  5. angelslim/compressor/quant/__init__.py +18 -0
  6. angelslim/compressor/quant/core/__init__.py +25 -0
  7. angelslim/compressor/quant/core/config.py +166 -0
  8. angelslim/compressor/quant/core/dit_hook.py +92 -0
  9. angelslim/compressor/quant/core/hook.py +184 -0
  10. angelslim/compressor/quant/core/metrics.py +88 -0
  11. angelslim/compressor/quant/core/packing_utils.py +117 -0
  12. angelslim/compressor/quant/core/quant_func.py +316 -0
  13. angelslim/compressor/quant/core/sample_func.py +45 -0
  14. angelslim/compressor/quant/core/save.py +211 -0
  15. angelslim/compressor/quant/modules/__init__.py +25 -0
  16. angelslim/compressor/quant/modules/awq/__init__.py +13 -0
  17. angelslim/compressor/quant/modules/awq/auto_clip.py +290 -0
  18. angelslim/compressor/quant/modules/awq/auto_scale.py +222 -0
  19. angelslim/compressor/quant/modules/awq/awq.py +334 -0
  20. angelslim/compressor/quant/modules/awq/search.py +137 -0
  21. angelslim/compressor/quant/modules/catcher.py +33 -0
  22. angelslim/compressor/quant/modules/fp8/__init__.py +13 -0
  23. angelslim/compressor/quant/modules/fp8/fp8.py +142 -0
  24. angelslim/compressor/quant/modules/gptq/__init__.py +13 -0
  25. angelslim/compressor/quant/modules/gptq/gptq.py +319 -0
  26. angelslim/compressor/quant/modules/gptq/gptq_module.py +219 -0
  27. angelslim/compressor/quant/modules/helper_layer.py +628 -0
  28. angelslim/compressor/quant/modules/int8/__init__.py +13 -0
  29. angelslim/compressor/quant/modules/int8/int8.py +142 -0
  30. angelslim/compressor/quant/modules/smooth/__init__.py +13 -0
  31. angelslim/compressor/quant/modules/smooth/smooth.py +88 -0
  32. angelslim/compressor/quant/observers/__init__.py +23 -0
  33. angelslim/compressor/quant/observers/abs_max_activation.py +216 -0
  34. angelslim/compressor/quant/observers/abs_max_weight.py +77 -0
  35. angelslim/compressor/quant/observers/base_observer.py +115 -0
  36. angelslim/compressor/quant/observers/ema_activation.py +63 -0
  37. angelslim/compressor/quant/observers/groupwise_weight.py +85 -0
  38. angelslim/compressor/quant/observers/hist_activation.py +244 -0
  39. angelslim/compressor/quant/observers/observer.py +70 -0
  40. angelslim/compressor/quant/ptq.py +188 -0
  41. angelslim/compressor/sparsity/__init__.py +13 -0
  42. angelslim/compressor/speculative_decoding/__init__.py +13 -0
  43. angelslim/data/__init__.py +9 -0
  44. angelslim/data/base_dataset.py +54 -0
  45. angelslim/data/dataloader.py +95 -0
  46. angelslim/data/multimodal_dataset.py +153 -0
  47. angelslim/data/text_dataset.py +162 -0
  48. angelslim/engine.py +183 -0
  49. angelslim/models/__init__.py +17 -0
  50. angelslim/models/base_model.py +309 -0
  51. angelslim/models/llm/__init__.py +19 -0
  52. angelslim/models/llm/deepseek.py +79 -0
  53. angelslim/models/llm/hunyuan_dense.py +71 -0
  54. angelslim/models/llm/hunyuan_moe.py +94 -0
  55. angelslim/models/llm/llama.py +68 -0
  56. angelslim/models/llm/qwen.py +92 -0
  57. angelslim/models/model_factory.py +69 -0
  58. angelslim/models/vlm/__init__.py +16 -0
  59. angelslim/models/vlm/qwen_vl.py +78 -0
  60. angelslim/utils/__init__.py +25 -0
  61. angelslim/utils/config_parser.py +389 -0
  62. angelslim/utils/default_compress_config.py +128 -0
  63. angelslim/utils/utils.py +134 -0
  64. angelslim-0.1.0.dist-info/METADATA +45 -0
  65. angelslim-0.1.0.dist-info/RECORD +68 -0
  66. angelslim-0.1.0.dist-info/WHEEL +5 -0
  67. angelslim-0.1.0.dist-info/licenses/LICENSE +87 -0
  68. angelslim-0.1.0.dist-info/top_level.txt +1 -0
angelslim/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright 2025 Tencent Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .engine import Engine # noqa: F401
angelslim/compressor/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright 2025 Tencent Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .compressor_factory import CompressorFactory # noqa: F401
angelslim/compressor/compressor_factory.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright 2025 Tencent Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Callable, Dict, Optional, Type, Union
16
+
17
+ from ..utils import print_info
18
+
19
+
20
class CompressorFactory:
    """
    Registry of compression-method classes plus a factory that builds them.

    Classes are stored under a string key -- either an explicit name or the
    class's own ``__name__`` -- and instantiated later by that key.
    """

    # Process-wide registry shared by all users of the class: key -> class.
    _compress_methods: Dict[str, Type[Any]] = {}

    @classmethod
    def register(cls, name: Optional[Union[str, Callable]] = None) -> Callable:
        """Register a compression class; usable as a decorator in three forms.

        * ``@CompressorFactory.register("explicit_name")`` stores the class
          under ``"explicit_name"``.
        * ``@CompressorFactory.register()`` and the bare
          ``@CompressorFactory.register`` store the class under its
          ``__name__``.

        Re-registering an existing key overwrites it after emitting a notice
        through ``print_info``. The decorated class is returned unchanged.

        Raises:
            TypeError: if ``name`` is not ``None``, a string, or callable.
        """

        def _by_class_name(compress_cls: Type[Any]) -> Type[Any]:
            """Store ``compress_cls`` under its own ``__name__``."""
            class_key = compress_cls.__name__
            already_known = class_key in cls._compress_methods
            if already_known:
                print_info(
                    f"Compression method '{class_key}' already exists, will be overwritten."
                )
            cls._compress_methods[class_key] = compress_cls
            return compress_cls

        def _by_explicit_name(
            explicit_key: str,
        ) -> Callable[[Type[Any]], Type[Any]]:
            """Build a decorator that stores a class under ``explicit_key``."""

            def _decorate(compress_cls: Type[Any]) -> Type[Any]:
                if explicit_key in cls._compress_methods:
                    print_info(
                        f"register '{explicit_key}' already exists, will be overwritten."
                    )
                cls._compress_methods[explicit_key] = compress_cls
                return compress_cls

            return _decorate

        if name is None:
            # Called with empty parentheses: @CompressorFactory.register()
            return _by_class_name
        if isinstance(name, str):
            return _by_explicit_name(name)
        if callable(name):
            # Used bare (@CompressorFactory.register): ``name`` is the class.
            return _by_class_name(name)
        raise TypeError("Invalid argument type for registration")

    @classmethod
    def create(cls, name: str, model: Any, slim_config: Any) -> Any:
        """Instantiate the compressor registered under ``name``.

        Returns:
            ``registered_cls(model, slim_config)``.

        Raises:
            ValueError: if nothing is registered under ``name``.
        """
        registry = cls._compress_methods
        if name in registry:
            return registry[name](model, slim_config)
        available = list(registry.keys())
        raise ValueError(
            f"Compression method '{name}' not registered. Available: {available}"
        )

    @classmethod
    def get_available_compressor(cls) -> list:
        """Return every registered key, in dictionary insertion order."""
        return [key for key in cls._compress_methods]
angelslim/compressor/distill/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Tencent Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
angelslim/compressor/quant/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright 2025 Tencent Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .core import * # noqa: F401 F403
16
+ from .modules import * # noqa: F401 F403
17
+ from .observers import * # noqa: F401 F403
18
+ from .ptq import PTQ # noqa: F401
angelslim/compressor/quant/core/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright 2025 Tencent Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .config import * # noqa: F401 F403
16
+ from .hook import DiTHook, PTQHook # noqa: F401
17
+ from .metrics import mse_loss, snr_loss # noqa: F401
18
+ from .packing_utils import dequantize_gemm # noqa: F401
19
+ from .quant_func import * # noqa: F401 F403
20
+ from .sample_func import EMASampler, MultiStepSampler # noqa: F401
21
+ from .save import PTQPTMSave # noqa: F401
22
+ from .save import PTQSaveVllmHF # noqa: F401
23
+ from .save import PTQTorchSave # noqa: F401
24
+ from .save import PTQvLLMSaveHF # noqa: F401
25
+ from .save import PTQVLMSaveVllmHF # noqa: F401
angelslim/compressor/quant/core/config.py ADDED
@@ -0,0 +1,166 @@
1
+ # Copyright 2025 Tencent Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List
16
+
17
+ from ..observers import (
18
+ AbsMaxChannelWiseWeightObserver,
19
+ AbsMaxGroupWiseWeightObserver,
20
+ AbsmaxPerchannelObserver,
21
+ AbsmaxPertensorObserver,
22
+ )
23
+
24
# Activation observer classes, looked up by the granularity given in
# ``quant_method["activation"]``.
ACT_OBSERVERS_CLASS = dict(
    [
        ("per-tensor", AbsmaxPertensorObserver),
        ("per-channel", AbsmaxPerchannelObserver),
    ]
)

# Weight observer classes, looked up by the granularity given in
# ``quant_method["weight"]``.
WEIGHT_OBSERVERS_CLASS = dict(
    [
        ("per-tensor", AbsmaxPertensorObserver),
        ("per-channel", AbsMaxChannelWiseWeightObserver),
        ("per-group", AbsMaxGroupWiseWeightObserver),
    ]
)

# Granularity keys accepted for the KV-cache observer.
KVCACHE_OBSERVERS_CLASS = dict([("per-channel", AbsmaxPerchannelObserver)])
35
+
36
+
37
class QuantConfig:
    r"""
    Configure how to quantize a model or a part of the model. It maps layers
    to observer classes according to the quantization settings.

    Args:
        config: Parsed compression config. ``config.quantization`` provides
            ``name`` (the algorithm, e.g. ``fp8_static``), ``bits``,
            ``quant_helpers``, the ``quant_method`` dict, ``ignore_layers``
            and, depending on the algorithm, ``low_memory`` and
            ``smooth_alpha``.
        global_config: Global model config providing ``max_seq_length``,
            ``hidden_size`` and ``model_arch_type``. Despite the ``None``
            default it is required: ``max_seq_length`` is read
            unconditionally.

    Examples:
        .. code-block:: python

            from angelslim.compressor.quant import QuantConfig
            q_config = QuantConfig(slim_config, global_config)
    """

    def __init__(self, config, global_config=None):
        # Observer classes; filled in per algorithm below.
        self.act_observer = None
        self.weight_observer = None
        self.kv_cache_observer = None

        quantization_args = config.quantization
        self.quant_algo = quantization_args.name
        self.quant_bit = quantization_args.bits
        self.max_seq_length = global_config.max_seq_length
        self.quant_helpers = quantization_args.quant_helpers
        act_quant_method = quantization_args.quant_method.get("activation", None)
        weight_quant_method = quantization_args.quant_method["weight"]

        if "fp8" in self.quant_algo:
            self._init_8bit(
                "fp8",
                quantization_args,
                global_config,
                act_quant_method,
                weight_quant_method,
            )
        elif "int8" in self.quant_algo:
            self._init_8bit(
                "int8",
                quantization_args,
                global_config,
                act_quant_method,
                weight_quant_method,
            )
        elif "int4_awq" in self.quant_algo:
            self.act_observer = None
            self.weight_observer = None
            self.kv_cache_observer = None
            group_size = self._resolve_group_size(quantization_args)
            self.quant_algo_info = {
                "zero_point": quantization_args.quant_method["zero_point"],
                "group_size": int(group_size),
                "mse_range": quantization_args.quant_method["mse_range"],
            }
            self.hidden_size = global_config.hidden_size
            self.model_arch_type = global_config.model_arch_type
        elif "int4_gptq" in self.quant_algo:
            self.act_observer = None
            self.weight_observer = None
            self.kv_cache_observer = None
            group_size = self._resolve_group_size(quantization_args)
            self.quant_algo_info = {
                "group_size": group_size,
                "ignore_layers": quantization_args.ignore_layers,
            }
            self.hidden_size = global_config.hidden_size

        if "smooth" in self.quant_helpers:
            self.smooth_alpha = quantization_args.smooth_alpha
            self.smooth_observer = ACT_OBSERVERS_CLASS["per-channel"]
        self.custom_observe_layers_names = "default"

    def _init_8bit(
        self,
        prefix,
        quantization_args,
        global_config,
        act_quant_method,
        weight_quant_method,
    ):
        """Shared setup for the ``fp8`` / ``int8`` (static or dynamic) algorithms.

        Args:
            prefix: ``"fp8"`` or ``"int8"``; used in ``quant_algo_info`` tags.

        Raises:
            ValueError: for a static algorithm without an activation method.
        """
        mode = "dynamic" if "dynamic" in self.quant_algo else "static"
        # The old ``assert mode or ...`` could never fire because ``mode`` is
        # always a non-empty string; check the real condition explicitly.
        if mode == "static" and act_quant_method is None:
            raise ValueError(
                f"[OpenSlim][Error] {prefix}_static need act_quant_method"
            )
        # Dynamic activation quantization needs no calibration observer.
        self.act_observer = (
            ACT_OBSERVERS_CLASS[act_quant_method] if mode == "static" else None
        )
        self.weight_observer = WEIGHT_OBSERVERS_CLASS[weight_quant_method]
        self.kv_cache_observer = None
        self.quant_algo_info = {
            "w": f"{prefix}_{weight_quant_method}",
            "ignore_layers": quantization_args.ignore_layers,
        }
        if act_quant_method is not None:
            self.quant_algo_info["a"] = f"{prefix}_{act_quant_method}-{mode}"
        self.hidden_size = global_config.hidden_size
        self.model_arch_type = global_config.model_arch_type
        self.low_memory = quantization_args.low_memory

    @staticmethod
    def _resolve_group_size(quantization_args):
        """Return ``quant_method["group_size"]``, mapping ``-1`` to 128."""
        group_size = quantization_args.quant_method["group_size"]
        return 128 if group_size == -1 else group_size

    def custom_observe_layers(
        self,
        names: List,
        act_observer="default",
        weight_observer="default",
        kv_cache_observer="default",
    ):
        """
        Restrict observation to ``names`` (fuzzy-matched) and optionally
        override observers.

        Each observer argument is a granularity key (e.g. ``"per-tensor"``).
        A key present in the matching ``*_OBSERVERS_CLASS`` table replaces the
        current observer with that table's class. Any other value, including
        ``"default"``, keeps the current observer.
        """
        self.custom_observe_layers_names = names
        # Store the observer *class* (as everywhere else in this class), not
        # the raw key string the caller passed in.
        if act_observer in ACT_OBSERVERS_CLASS:
            self.act_observer = ACT_OBSERVERS_CLASS[act_observer]
        if weight_observer in WEIGHT_OBSERVERS_CLASS:
            self.weight_observer = WEIGHT_OBSERVERS_CLASS[weight_observer]
        if kv_cache_observer in KVCACHE_OBSERVERS_CLASS:
            self.kv_cache_observer = KVCACHE_OBSERVERS_CLASS[kv_cache_observer]
angelslim/compressor/quant/core/dit_hook.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright 2025 Tencent Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+
21
# Sub-modules whose names match this pattern are never hooked.
_FILTER_PATTERN = re.compile(
    r".*(mlp_t5|pooler|style_embedder|x_embedder|t_embedder|extra_embedder).*"
)


def filter_func(name):
    """Return True if the sub-module called ``name`` must be skipped."""
    return _FILTER_PATTERN.match(name) is not None


class DiTHook:
    """Accumulate inputs/outputs of every ``nn.Linear`` in ``model``.

    For each hooked layer, ``input_activation`` / ``output_activation`` hold
    the element-wise running sum (on CPU, detached) of the first positional
    input / first output. ``input_activation_cnt`` / ``output_activation_cnt``
    hold the number of calls, presumably so callers can average (sum / count).
    """

    def __init__(self, model, use_transformer_engine=False):
        """
        Args:
            model (nn.Module, required): the model to be quantized.
            use_transformer_engine (bool): stored as an attribute; not read
                by this class.
        """
        self.model = model
        self.input_activation = {}
        self.output_activation = {}
        self.input_activation_cnt = {}
        self.output_activation_cnt = {}
        self.use_transformer_engine = use_transformer_engine
        self._apply_hook()

    def _apply_hook(self):
        """Register a forward hook on every unfiltered ``nn.Linear``."""
        self._forward_hook_list = []
        # layer -> name, resolved once here instead of rescanning the whole
        # model on every forward call.
        self._layer_names = {}
        for name, sub_layer in self.model.named_modules():
            if filter_func(name):
                continue
            if isinstance(sub_layer, nn.Linear):
                self._layer_names[sub_layer] = name
                handle = sub_layer.register_forward_hook(self._forward_pre_hook)
                self._forward_hook_list.append(handle)

    @staticmethod
    def _accumulate(sums, counts, name, tensor):
        """Add ``tensor`` to ``sums[name]`` and bump ``counts[name]``."""
        previous = sums.get(name)
        # Clone on first sight so the stored sum never aliases a tensor the
        # model still owns (``.cpu()`` is a no-op for CPU tensors).
        sums[name] = tensor.clone() if previous is None else previous + tensor
        counts[name] = counts.get(name, 0) + 1

    def _forward_pre_hook(self, layer, input, output):
        """Forward hook (despite its name): accumulate this call's activations."""
        layer_name = self._layer_names.get(layer, "")
        out = output[0] if isinstance(output, tuple) else output
        inp = input[0] if isinstance(input, tuple) else input
        self._accumulate(
            self.output_activation,
            self.output_activation_cnt,
            layer_name,
            out.detach().cpu(),
        )
        self._accumulate(
            self.input_activation,
            self.input_activation_cnt,
            layer_name,
            inp.detach().cpu(),
        )

    def remove_hook(self):
        """Detach every registered forward hook."""
        for hook in self._forward_hook_list:
            hook.remove()
        self._forward_hook_list = []

    def clean_acitvation_list(self):
        """Drop accumulated activations together with their call counts."""
        self.input_activation = {}
        self.output_activation = {}
        # Counts must be reset too, or later sums and counts disagree.
        self.input_activation_cnt = {}
        self.output_activation_cnt = {}

    # Correctly spelled alias; the original name is kept for compatibility.
    clean_activation_list = clean_acitvation_list
angelslim/compressor/quant/core/hook.py ADDED
@@ -0,0 +1,184 @@
1
+ # Copyright 2025 Tencent Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+
17
+ import torch
18
+
19
+ from ..observers import ParentObserver, PTQObserver
20
+ from .quant_func import get_fp_maxval, get_fp_search_maxval
21
+
22
# Public hook classes exported by this module.
__all__ = [
    "PTQHook",
    "DiTHook",
]
23
+
24
+
25
class PTQHook:
    """Attach PTQ observers to a model's layers through forward hooks.

    ``apply_hook`` creates one ``PTQObserver`` per layer returned by
    ``model.get_observer_layers()``. ``apply_smooth_hook`` does the same for
    smoothing layers. Every hooked forward call feeds cloned input/output
    tensors to the layer's observer. ``post_process`` divides the collected
    fp8 scales by the fp8 max value.
    """

    def __init__(self, model):
        self.quant_model = model
        self._forward_hook_list = []
        # name -> layer, filled by apply_hook()
        self.quant_layers_dict = {}
        # layer -> observer
        self.observer_dict = {}
        self.kv_names = []

    def _attach(self, layer, observer):
        """Hook ``layer`` with ``_forward_hook`` and remember its observer."""
        handle = layer.register_forward_hook(self._forward_hook)
        self.observer_dict[layer] = observer
        self._forward_hook_list.append(handle)

    def apply_hook(self):
        """Create and attach an observer for every quantizable layer."""
        model = self.quant_model
        self.quant_layers_dict = model.get_observer_layers()
        self.kv_names = model.get_kvcache_observer_layers_names(
            self.quant_layers_dict.keys()
        )
        algo = model.quant_algo_dict
        act_observer = algo["act_observer"]
        weight_observer = algo["weight_observer"]
        kv_cache_observer = algo["kv_cache_observer"]

        quant_parent_dict = model.get_parent_dict(self.quant_layers_dict)
        # One shared ParentObserver per distinct parent.
        parent_observers = {
            parent: ParentObserver() for parent in set(quant_parent_dict.values())
        }

        for layer_name, layer in self.quant_layers_dict.items():
            extra_kwargs = {}
            if layer_name in quant_parent_dict:
                extra_kwargs["parent_observer"] = parent_observers[
                    quant_parent_dict[layer_name]
                ]
            # Only layers listed in kv_names get a KV-cache observer.
            layer_kv_observer = (
                kv_cache_observer if layer_name in self.kv_names else None
            )
            observer = PTQObserver(
                layer,
                act_observer,
                weight_observer,
                layer_kv_observer,
                model.quant_algo_dict,
                **extra_kwargs
            )
            self._attach(layer, observer)

    def apply_smooth_hook(self, smooth_mapping_layers, smooth_observer):
        """Attach activation-only observers to each smoothing layer.

        ``smooth_mapping_layers`` values are pairs whose first item is the
        layer to observe; the second item is ignored here.
        """
        for smooth_layer, _paired in smooth_mapping_layers.values():
            observer = PTQObserver(
                smooth_layer,
                act_observer=None,
                weight_observer=None,
                kv_cache_observer=None,
                quant_algo_dict=self.quant_model.quant_algo_dict,
                smooth_act_observer=smooth_observer,
            )
            self._attach(smooth_layer, observer)

    def _forward_hook(self, layer, input, output):
        """Feed cloned activations to ``layer``'s observer; pass output through."""
        act_in = (input[0] if isinstance(input, tuple) else input).clone()
        act_out = (output[0] if isinstance(output, tuple) else output).clone()
        model = self.quant_model
        # Some models ask for a layer norm to be applied to the observed input.
        if (
            hasattr(model, "apply_layer_norm_list")
            and layer in model.apply_layer_norm_list
        ):
            act_in = model.apply_layer_norm(layer, act_in)
        self.observer_dict[layer](act_in, act_out)
        return output

    def remove_hook(self):
        """Detach every forward hook registered by this object."""
        for handle in self._forward_hook_list:
            handle.remove()
        self._forward_hook_list = []

    def post_process(self):
        """Divide fp8 weight/activation/KV-cache scales by the fp8 max value.

        Activation scales of observers whose ``str()`` contains ``"Search"``
        use ``get_fp_search_maxval(observer.sampled_input)`` instead.
        """
        maxval = get_fp_maxval(bits=8)
        model = self.quant_model
        algo = model.quant_algo_dict

        if algo["w_quant_algo"] == "fp8":
            weight_scales = model.weight_scales_dict
            for key, scale in weight_scales.items():
                weight_scales[key] = scale / maxval.type(scale.dtype)

        if algo["a_quant_algo"] == "fp8":
            for layer_name, layer in self.quant_layers_dict.items():
                if layer not in self.observer_dict:
                    continue
                act_scales = model.act_scales_dict
                if layer_name not in act_scales:
                    continue
                current = act_scales[layer_name]
                observer = self.observer_dict[layer]
                if "Search" in str(observer):
                    divisor = get_fp_search_maxval(observer.sampled_input)
                else:
                    divisor = maxval
                act_scales[layer_name] = current / divisor.type(current.dtype)

        if algo["c_quant_algo"] == "fp8":
            kv_scales = model.kv_cache_scales_dict
            for key, scale in kv_scales.items():
                kv_scales[key] = scale / maxval.type(scale.dtype)
122
+
123
+
124
+ def _filter_func(name):
125
+ pattern = re.compile(
126
+ r".*(mlp_t5|pooler|style_embedder|x_embedder|t_embedder|extra_embedder).*"
127
+ )
128
+ return pattern.match(name) is not None
129
+
130
+
131
+ class DiTHook:
132
+ def __init__(self, model):
133
+ """
134
+ Args:
135
+ model(nn.Moudle, required): the model to be quant
136
+ """
137
+ self.model = model
138
+ self.input_activation = []
139
+ self.output_activation = []
140
+
141
+ self._apply_hook()
142
+
143
+ def _apply_hook(self):
144
+ self._forward_hook_list = []
145
+ for name, sub_layer in self.model.named_modules():
146
+ if _filter_func(name):
147
+ continue
148
+ if isinstance(sub_layer, (torch.nn.Conv2d, torch.nn.Linear)):
149
+ if "blocks" in name:
150
+ # handle
151
+ forward_pre_hook_handle = sub_layer.register_forward_hook(
152
+ self._forward_pre_hook
153
+ )
154
+ self._forward_hook_list.append(forward_pre_hook_handle)
155
+
156
+ def _forward_pre_hook(self, layer, input, output):
157
+ layer_name = ""
158
+ for name, module in self.model.named_modules():
159
+ if _filter_func(name):
160
+ continue
161
+ if module == layer:
162
+ layer_name = name
163
+ break
164
+ x = (
165
+ output[0].detach().cpu()
166
+ if isinstance(output, tuple)
167
+ else output.detach().cpu()
168
+ )
169
+ self.output_activation.append((layer_name, x))
170
+ y = (
171
+ input[0].detach().cpu()
172
+ if isinstance(input, tuple)
173
+ else input.detach().cpu()
174
+ )
175
+ self.input_activation.append((layer_name, y))
176
+
177
+ def remove_hook(self):
178
+ for hook in self._forward_hook_list:
179
+ hook.remove()
180
+ self._forward_hook_list = []
181
+
182
+ def clean_acitvation_list(self):
183
+ self.input_activation = []
184
+ self.output_activation = []