tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
@@ -1,225 +0,0 @@
1
- # Copyright (c) 2024 Intel Corporation
2
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import types
17
- from typing import Any, Dict, List, Optional
18
-
19
- import torch
20
-
21
- from tico.experimental.quantization.algorithm.gptq.gptq import GPTQ
22
- from tico.experimental.quantization.algorithm.gptq.utils import (
23
- find_layers,
24
- gather_single_batch_from_dict,
25
- gather_single_batch_from_list,
26
- )
27
- from tico.experimental.quantization.config import BaseConfig, GPTQConfig
28
- from tico.experimental.quantization.quantizer import BaseQuantizer
29
-
30
-
class GPTQQuantizer(BaseQuantizer):
    """
    Quantizer for applying the GPTQ algorithm (typically for weight quantization).

    Workflow:
      1. ``prepare`` runs the model on example inputs and records the arguments
         passed to ``model.model.layers[0]`` (or to ``model`` itself when it has
         no ``model`` attribute). Each call records one calibration batch.
      2. ``convert`` walks the layers in order. For each layer it collects
         statistics for every submodule returned by ``find_layers`` from the
         recorded inputs, quantizes those submodules with GPTQ, then re-runs the
         quantized layer so its outputs become the recorded inputs of the next layer.
    """

    class _CaptureDone(Exception):
        """
        Raised by the temporary forward installed in ``prepare`` to stop the model
        right after the target's inputs have been recorded. A dedicated type is
        used so genuine errors raised by the model are not swallowed.
        """

    def __init__(self, config: BaseConfig):
        super().__init__(config)

        # cache_args[i][b]: i-th positional input of calibration batch b.
        self.cache_args: List[Any] = []
        # cache_kwargs[key][b]: keyword input `key` of batch b, plus the
        # "batch_num" counter of recorded batches.
        self.cache_kwargs: Dict[str, Any] = {"batch_num": 0}

    @staticmethod
    def _capture_target(model: torch.nn.Module) -> torch.nn.Module:
        """Return the module whose inputs ``prepare`` records."""
        if hasattr(model, "model"):
            assert hasattr(model.model, "layers")
            assert isinstance(model.model.layers, torch.nn.ModuleList)
            return model.model.layers[0]
        assert hasattr(model, "forward")
        return model

    @torch.no_grad()
    def prepare(
        self,
        model: torch.nn.Module,
        args: Optional[Any] = None,
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Record the calibration inputs of the first quantization target.

        The forward of ``model.model.layers[0]`` (or of ``model`` itself when it
        has no ``model`` attribute) is temporarily replaced by a function that
        appends the received positional/keyword arguments to ``self.cache_args``
        / ``self.cache_kwargs`` and then aborts the forward pass. The model is run
        once on the given example inputs; the original forward is always
        restored, even if the model raises.

        Parameters:
            model: The target PyTorch model.
            args: Positional example inputs (unpacked with ``*``) for the model call.
            kwargs: Keyword example inputs for the model call.

        Returns:
            The same model, with its original forward restored.

        Raises:
            RuntimeError: If both ``args`` and ``kwargs`` are None.
        """
        if args is None and kwargs is None:
            raise RuntimeError(
                "Either args or kwargs must be provided for capturing graph."
            )
        # Either one may be omitted; normalize so `model(*args, **kwargs)` works.
        args = () if args is None else args
        kwargs = {} if kwargs is None else kwargs

        def capture(_module, *layer_args, **layer_kwargs):
            # Record one batch: positional inputs by position, keyword inputs by name.
            self.cache_kwargs["batch_num"] += 1
            for idx, item in enumerate(layer_args):
                if idx >= len(self.cache_args):
                    self.cache_args.append([])
                self.cache_args[idx].append(item)
            for key, value in layer_kwargs.items():
                if self.cache_kwargs.get(key) is None:
                    self.cache_kwargs[key] = []
                self.cache_kwargs[key].append(value)
            # Only the inputs are needed; abort the rest of the forward pass.
            raise self._CaptureDone

        target = self._capture_target(model)
        # Remember whether `forward` was already an instance attribute (e.g. set by
        # a wrapping library) so it can be restored exactly afterwards.
        had_instance_forward = "forward" in vars(target)
        saved_forward = target.forward
        target.forward = types.MethodType(capture, target)
        try:
            model(*args, **kwargs)  # type: ignore[misc]
        except self._CaptureDone:
            pass
        finally:
            if had_instance_forward:
                target.forward = saved_forward
            else:
                del target.forward

        return model

    def _run_cached_batch(self, layer: torch.nn.Module, batch_idx: int):
        """
        Run ``layer`` on recorded batch ``batch_idx`` and return element 0 of its
        output. Expects "batch_num" to be absent from ``self.cache_kwargs``.
        """
        batch_args = gather_single_batch_from_list(self.cache_args, batch_idx)
        batch_kwargs = gather_single_batch_from_dict(self.cache_kwargs, batch_idx)
        return layer(*batch_args, **batch_kwargs)[0]

    def _quantize_layer(
        self, layer: torch.nn.Module, l_idx: int, batch_num: int, gptq_conf
    ):
        """
        Calibrate and GPTQ-quantize every submodule that ``find_layers`` finds in
        ``layer``, using the ``batch_num`` recorded batches.
        """
        # All quantizable submodules of the layer are calibrated together in a
        # single pass over the cached batches.
        subset = dict(find_layers(layer))

        gptq: Dict[str, GPTQ] = {}
        for name, module in subset.items():
            gptq[name] = GPTQ(module)
            gptq[name].quantizer.configure(8, perchannel=True, sym=False, mse=False)

        def add_batch(name):
            # Forward hook feeding a submodule's input/output to its GPTQ state.
            def hook(_, inp, out):
                gptq[name].add_batch(inp[0].data, out.data)

            return hook

        handles = [
            module.register_forward_hook(add_batch(name))
            for name, module in subset.items()
        ]
        try:
            # Run the layer on the stored calibration inputs to collect statistics.
            for batch_idx in range(batch_num):
                self._run_cached_batch(layer, batch_idx)
        finally:
            for handle in handles:
                handle.remove()

        # Quantize each submodule using the collected calibration data.
        for name in subset:
            if gptq_conf.verbose:
                print(l_idx, name)
                print("Quantizing ...")
            gptq[name].fasterquant(
                percdamp=0.01,
                groupsize=-1,
                actorder=True,
                static_groups=False,
                verbose=gptq_conf.verbose,
            )
            gptq[name].free()

    @torch.no_grad()
    def convert(self, model):
        """
        Convert the prepared model to its GPTQ quantized version.

        Applies the GPTQ quantization on weights based on the collected statistics.
        ``model.config.use_cache`` (when the model has a ``config``) is disabled
        during the process and always restored afterwards.

        Parameters:
            model: The prepared PyTorch model.

        Returns:
            The quantized model.
        """
        gptq_conf = self.config
        assert isinstance(gptq_conf, GPTQConfig)

        # Save the original cache setting and disable caching during calibration/inference.
        has_config = hasattr(model, "config")
        use_cache = None
        if has_config:
            use_cache = model.config.use_cache
            model.config.use_cache = False

        target_layers = model.model.layers if hasattr(model, "model") else [model]

        # "batch_num" is a counter, not a layer input. It is taken out while batches
        # are replayed (presumably so gather_single_batch_from_dict does not pass
        # it to the layer) and is always put back afterwards.
        batch_num = self.cache_kwargs.pop("batch_num")
        try:
            for l_idx, layer in enumerate(target_layers):
                self._quantize_layer(layer, l_idx, batch_num, gptq_conf)

                # Execute the quantized layer on the calibration inputs; its outputs
                # become the next layer's inputs so quantization effects are
                # propagated to subsequent layers.
                for batch_idx in range(batch_num):
                    self.cache_args[0][batch_idx] = self._run_cached_batch(
                        layer, batch_idx
                    )

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        finally:
            self.cache_kwargs["batch_num"] = batch_num
            # Restore the original cache configuration.
            if has_config:
                model.config.use_cache = use_cache

        return model
@@ -1,164 +0,0 @@
1
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from typing import Any, Dict, List, Optional
16
-
17
- import torch
18
-
19
-
20
- @torch.no_grad()
21
- def smooth_weights(
22
- front_module: torch.nn.Module,
23
- back_modules: torch.nn.Module | List[torch.nn.Module],
24
- activation_max: torch.Tensor,
25
- alpha: float,
26
- ):
27
- """
28
- Applies SmoothQuant-style smoothing to the weights and biases of two
29
- connected modules using activation maximum values.
30
-
31
- NOTE All modules **MUST** have `weight` and optionally `bias` attributes.
32
-
33
- Parameters
34
- -----------
35
- front_module
36
- The front module whose weights and biases will be adjusted.
37
- back_modules
38
- A list of back modules whose weights and biases will be adjusted.
39
- activation_max
40
- A tensor of channel-wise maximum activation values for the front module.
41
- alpha
42
- The smoothing factor that determines the scaling for weight adjustments.
43
-
44
- Raises
45
- -------
46
- AttributeError
47
- If `front_module` or any module in `back_modules` does not have `weight` attributes.
48
- ValueError
49
- If the shape of tensors in `activation_max` does not match the number of channels
50
- in `front_module`'s weight.
51
- NoteImplementedError
52
- If `front_module` or any module in `back_modules` is of an unsupported type.
53
- """
54
- from transformers.models.llama.modeling_llama import LlamaRMSNorm
55
-
56
- if not isinstance(back_modules, list):
57
- back_modules = [back_modules]
58
-
59
- # Check attributes
60
- if not hasattr(front_module, "weight"):
61
- raise AttributeError(
62
- f"The front module '{type(front_module).__name__}' does not have a 'weight' attribute."
63
- )
64
- for back_m in back_modules:
65
- if not hasattr(back_m, "weight"):
66
- raise AttributeError(
67
- f"The front module '{type(back_m).__name__}' does not have a 'weight' attribute."
68
- )
69
- # Check shapes
70
- if isinstance(front_module, LlamaRMSNorm):
71
- front_numel = front_module.weight.numel()
72
- else:
73
- raise NotImplementedError(
74
- f"Unsupported module type: {type(front_module).__name__}"
75
- )
76
- for back_m in back_modules:
77
- if isinstance(back_m, torch.nn.Linear):
78
- back_numel = back_m.in_features
79
- else:
80
- raise NotImplementedError(
81
- f"Unsupported module type: {type(front_module).__name__}"
82
- )
83
-
84
- if front_numel != back_numel or back_numel != activation_max.numel():
85
- raise ValueError(
86
- f"Shape mismatch: front_numel({front_numel}), back_numel({back_numel}), activation_max_numel({activation_max.numel()})"
87
- )
88
-
89
- # Compute scales
90
- device, dtype = back_modules[0].weight.device, back_modules[0].weight.dtype
91
- activation_max = activation_max.to(device=device, dtype=dtype) # type: ignore[arg-type]
92
- weight_scales = torch.cat(
93
- [back_m.weight.abs().max(dim=0, keepdim=True)[0] for back_m in back_modules], # type: ignore[operator]
94
- dim=0,
95
- )
96
- weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
97
- scales = (
98
- (activation_max.pow(alpha) / weight_scales.pow(1 - alpha))
99
- .clamp(min=1e-5)
100
- .to(device) # type: ignore[arg-type]
101
- .to(dtype) # type: ignore[arg-type]
102
- )
103
-
104
- # Smooth
105
- front_module.weight.div_(scales)
106
- if hasattr(front_module, "bias"):
107
- front_module.bias.div_(scales)
108
-
109
- for back_m in back_modules:
110
- back_m.weight.mul_(scales.view(1, -1)) # type: ignore[operator]
111
-
112
-
@torch.no_grad()
def apply_smoothing(
    model: torch.nn.Module,
    activation_max: Dict[str, torch.Tensor],
    alpha: float = 0.5,
    custom_alpha_map: Optional[Dict[str, float]] = None,
):
    """
    Applies SmoothQuant-style smoothing to the model's weights using activation maximum values.

    For every `LlamaDecoderLayer` named `name` in `model`:
      - `input_layernorm` is smoothed against the q/k/v projections using
        `activation_max[name + ".self_attn.q_proj"]`;
      - `post_attention_layernorm` is smoothed against `mlp.gate_proj` and
        `mlp.up_proj` using `activation_max[name + ".mlp.gate_proj"]`.
    Weights are modified in place.

    Parameters
    -----------
    model
        A torch module whose weights will be smoothed.
    activation_max
        The channel-wise maximum activation values, keyed by module name.
    alpha
        The default smoothing factor to apply across all modules.
    custom_alpha_map
        A dictionary mapping layer/module names to custom alpha values.
        Layers specified in this dictionary will use the corresponding alpha
        value instead of the default.

    Raises
    -------
    RuntimeError
        If `alpha` or any value in `custom_alpha_map` is outside [0, 1]. All
        values are validated before any weight is modified.
    """
    custom_alpha_map = custom_alpha_map or {}

    # Validate every alpha up front: smoothing mutates weights in place, so
    # failing midway would leave the model only partially smoothed.
    for candidate in (alpha, *custom_alpha_map.values()):
        if candidate > 1.0:
            raise RuntimeError(
                f"Alpha value cannot exceed 1.0. Given alpha: {candidate}"
            )
        if candidate < 0.0:
            raise RuntimeError(
                f"Alpha value cannot be negative. Given alpha: {candidate}"
            )

    from transformers.models.llama.modeling_llama import LlamaDecoderLayer

    for name, module in model.named_modules():
        # SmoothQuant is applied before capturing the graph. Therefore, it needs to
        # know specific module information.
        # TODO Support more modules.
        if not isinstance(module, LlamaDecoderLayer):
            continue
        alpha_to_apply = custom_alpha_map.get(name, alpha)

        attn_ln = module.input_layernorm
        qkv = [
            module.self_attn.q_proj,
            module.self_attn.k_proj,
            module.self_attn.v_proj,
        ]
        qkv_input_scales = activation_max[name + ".self_attn.q_proj"]
        smooth_weights(attn_ln, qkv, qkv_input_scales, alpha_to_apply)

        ffn_ln = module.post_attention_layernorm
        fcs = [module.mlp.gate_proj, module.mlp.up_proj]
        fcs_input_scales = activation_max[name + ".mlp.gate_proj"]
        smooth_weights(ffn_ln, fcs, fcs_input_scales, alpha_to_apply)
@@ -1,109 +0,0 @@
1
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from typing import Any, Callable, Dict, List
16
-
17
- import numpy as np
18
- import torch
19
-
20
-
def compute_peir(base: torch.Tensor, target: torch.Tensor):
    """
    Calculate the Peak Error to Interval Ratio (PEIR) between two tensors.

        PEIR = max(abs(target - base)) / (max(base) - min(base))

    The interval is taken from `base` (the reference). When `base` is constant
    the interval is zero; in that case the peak absolute error is returned
    instead of dividing by zero.

    Parameters
    ----------
    base
        Reference tensor; must be float32. It is detached and moved to CPU.
    target
        Tensor to compare; must have the same shape as `base` and be float32.

    Returns
    -------
    A NumPy floating-point scalar holding the PEIR value.

    Raises
    ------
    AssertionError
        If the shapes differ or either tensor is not float32.
    """
    assert base.shape == target.shape, f"shape mismatch: {base.shape} != {target.shape}"
    # detach/cpu so tensors that require grad or live on an accelerator are accepted.
    base_tensor = base.detach().cpu().numpy()
    target_tensor = target.detach().cpu().numpy()
    assert (
        base_tensor.dtype == np.float32 and target_tensor.dtype == np.float32
    ), f"dtype should be float32: base({base_tensor.dtype}), target({target_tensor.dtype})"

    base_tensor = base_tensor.reshape(-1)
    target_tensor = target_tensor.reshape(-1)

    peak_error = np.max(np.absolute(target_tensor - base_tensor))
    interval = np.max(base_tensor) - np.min(base_tensor)
    if interval == 0.0:
        # Avoid zero interval: fall back to the unnormalized peak error.
        return peak_error
    return peak_error / interval  # pylint: disable=invalid-name
53
-
54
-
class MetricCalculator:
    """
    Compute metrics including both built-in and custom metrics.

    metrics
        A list of built-in metric names (keys of `builtin_metrics`) to compute.
    custom_metrics
        A dictionary of metric names and corresponding callable functions for comparison.
        Each callable is invoked as `fn(out1, out2)` on one pair of outputs.
        Example: {'mse': mean_squared_error, 'cosine_similarity': cosine_similarity_fn}

    Raises
        RuntimeError if a built-in name is unknown, or if a custom metric name
        duplicates one of the selected built-in metrics.
    """

    builtin_metrics = {
        "peir": compute_peir,
    }

    def __init__(
        self,
        metrics: List[str] | None = None,
        custom_metrics: Dict[str, Callable] | None = None,
    ):
        metrics = [] if metrics is None else metrics
        custom_metrics = {} if custom_metrics is None else custom_metrics

        # Selected metrics, name -> callable (built-in first, then custom).
        self.metrics: Dict[str, Callable] = dict()

        for m in metrics:
            if m in self.builtin_metrics:
                self.metrics[m] = self.builtin_metrics[m]
            else:
                raise RuntimeError(f"Invalid metric: {m}")

        duplicates = set(self.metrics).intersection(custom_metrics.keys())
        if len(duplicates) != 0:
            raise RuntimeError(f"There are duplicate metrics: {duplicates}")

        self.metrics = self.metrics | custom_metrics

    def compute(
        self, output1: List[torch.Tensor], output2: List[torch.Tensor]
    ) -> Dict[str, List[Any]]:
        """
        Compute every selected metric (built-in and custom) pairwise over the outputs.

        Returns
        --------
        Dict[str, List[Any]]
            A dictionary mapping each metric name to the list of values computed
            for (output1[i], output2[i]), in order.

        Raises
        --------
        RuntimeError
            If `output1` and `output2` have different lengths.
        """
        if len(output1) != len(output2):
            raise RuntimeError(
                f"Output count mismatch: {len(output1)} != {len(output2)}"
            )

        # Look up callables in self.metrics (not builtin_metrics) so custom
        # metrics are evaluated too.
        return {
            name: [fn(out1, out2) for out1, out2 in zip(output1, output2)]
            for name, fn in self.metrics.items()
        }