tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl
This diff represents the changes between two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- tico/__init__.py +9 -1
- tico/config/base.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_aten_where_arg_type.py +1 -1
- tico/passes/cast_clamp_mixed_type_args.py +169 -0
- tico/passes/cast_mixed_type_args.py +4 -2
- tico/passes/const_prop_pass.py +1 -1
- tico/passes/convert_conv1d_to_conv2d.py +1 -1
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_addmm.py +0 -3
- tico/passes/decompose_batch_norm.py +2 -2
- tico/passes/decompose_fake_quantize.py +0 -3
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
- tico/passes/decompose_group_norm.py +0 -3
- tico/passes/legalize_predefined_layout_operators.py +2 -11
- tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
- tico/passes/lower_to_slice.py +1 -1
- tico/passes/merge_consecutive_cat.py +1 -1
- tico/passes/ops.py +1 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -6
- tico/passes/remove_redundant_reshape.py +5 -5
- tico/passes/segment_index_select.py +1 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
- tico/quantization/algorithm/gptq/quantizer.py +292 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
- tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
- tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
- tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +19 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +125 -0
- tico/serialize/circle_graph.py +12 -4
- tico/serialize/circle_mapping.py +76 -2
- tico/serialize/circle_serializer.py +253 -148
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_any.py +7 -14
- tico/serialize/operators/op_avg_pool2d.py +11 -4
- tico/serialize/operators/op_clamp.py +5 -7
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_conv2d.py +14 -6
- tico/serialize/operators/op_copy.py +26 -3
- tico/serialize/operators/op_cumsum.py +3 -1
- tico/serialize/operators/op_depthwise_conv2d.py +17 -7
- tico/serialize/operators/op_full_like.py +0 -2
- tico/serialize/operators/op_index_select.py +8 -1
- tico/serialize/operators/op_instance_norm.py +0 -6
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_log1p.py +3 -2
- tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
- tico/serialize/operators/op_mm.py +15 -131
- tico/serialize/operators/op_mul.py +2 -8
- tico/serialize/operators/op_pow.py +3 -1
- tico/serialize/operators/op_repeat.py +12 -3
- tico/serialize/operators/op_reshape.py +1 -1
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/serialize/operators/op_softmax.py +7 -14
- tico/serialize/operators/op_split_with_sizes.py +16 -8
- tico/serialize/operators/op_transpose_conv.py +11 -8
- tico/serialize/operators/op_view.py +2 -1
- tico/serialize/quant_param.py +5 -5
- tico/utils/convert.py +30 -17
- tico/utils/dtype.py +42 -0
- tico/utils/graph.py +1 -1
- tico/utils/model.py +2 -1
- tico/utils/padding.py +2 -2
- tico/utils/pytree_utils.py +134 -0
- tico/utils/record_input.py +102 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/serialize.py +16 -3
- tico/utils/signature.py +247 -0
- tico/utils/torch_compat.py +52 -0
- tico/utils/utils.py +50 -58
- tico/utils/validate_args_kwargs.py +38 -3
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
- tico-0.1.0.dev251102.dist-info/RECORD +271 -0
- tico/experimental/quantization/__init__.py +0 -1
- tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
- tico-0.1.0.dev250714.dist-info/RECORD +0 -209
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
--- a/tico/experimental/quantization/algorithm/gptq/quantizer.py
+++ /dev/null
@@ -1,225 +0,0 @@
-# Copyright (c) 2024 Intel Corporation
-# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import types
-from typing import Any, Dict, List, Optional
-
-import torch
-
-from tico.experimental.quantization.algorithm.gptq.gptq import GPTQ
-from tico.experimental.quantization.algorithm.gptq.utils import (
-    find_layers,
-    gather_single_batch_from_dict,
-    gather_single_batch_from_list,
-)
-from tico.experimental.quantization.config import BaseConfig, GPTQConfig
-from tico.experimental.quantization.quantizer import BaseQuantizer
-
-
-class GPTQQuantizer(BaseQuantizer):
-    """
-    Quantizer for applying the GPTQ algorithm (typically for weight quantization)
-    """
-
-    def __init__(self, config: BaseConfig):
-        super().__init__(config)
-
-        self.cache_args: List[Any] = []
-        self.cache_kwargs: Dict[str, Any] = {"batch_num": 0}
-
-    @torch.no_grad()
-    def prepare(
-        self,
-        model: torch.nn.Module,
-        args: Optional[Any] = None,
-        kwargs: Optional[Dict[str, Any]] = None,
-    ):
-        """
-        Overrides the forward method of the first LLaMA layer (layer 0) to capture the
-        input required for calibration.
-
-        This method modifies the original forward pass of LLaMA layer 0 so that the
-        inputs used during inference are intercepted and recorded. These captured inputs
-        are then utilized to calibrate the quantization parameters for the GPTQ.
-
-        Parameters:
-            model: The target PyTorch model.
-            args: Positional example inputs required for capturing graph.
-            kwargs: Keyword example inputs required for capturing graph.
-
-        Returns:
-            The model prepared for GPTQ quantization.
-        """
-        if args is None and kwargs is None:
-            raise RuntimeError(
-                "Either args or kwargs must be provided for captruing graph."
-            )
-        # Define a function to capture input activations and associated parameters.
-        def forward(layer, *args, **kwargs):
-            self.cache_kwargs["batch_num"] += 1
-            for idx, item in enumerate(args):
-                if (idx + 1) > len(self.cache_args):
-                    self.cache_args.append([])
-                self.cache_args[idx].append(item)
-            for arg in kwargs:
-                if self.cache_kwargs.get(arg, None) is None:
-                    self.cache_kwargs[arg] = []
-                self.cache_kwargs[arg].append(kwargs[arg])
-            # Raise an error to interrupt the forward pass after capturing data.
-            raise ValueError
-
-        # Replace the first layer with defined function to capture calibration data.
-        if hasattr(model, "model"):
-            assert hasattr(model.model, "layers")
-            assert isinstance(model.model.layers, torch.nn.ModuleList)
-            layer_forward_cache = model.model.layers[0].forward
-            model.model.layers[0].forward = types.MethodType(
-                forward, model.model.layers[0]
-            )
-        else:
-            assert hasattr(model, "forward")
-            layer_forward_cache = model.forward
-            model.forward = types.MethodType(forward, model.forward)
-
-        model_forward_cache = model.forward
-        # Replace model's forward to avoid ValueError
-        def model_forward(model, *args, **kwargs):
-            nonlocal model_forward_cache
-            try:
-                model_forward_cache(*args, **kwargs)
-            except ValueError:
-                pass
-
-        model.forward = types.MethodType(model_forward, model)
-        kwargs = kwargs or {}
-        model(*args, **kwargs)  # type: ignore[misc]
-
-        # Recover original forward
-        model.forward = model_forward_cache
-        if hasattr(model, "model"):
-            assert hasattr(model.model, "layers")
-            assert isinstance(model.model.layers, torch.nn.ModuleList)
-            model.model.layers[0].forward = layer_forward_cache
-        else:
-            model.forward = layer_forward_cache
-
-        return model
-
-    @torch.no_grad()
-    def convert(self, model):
-        """
-        Convert the prepared model to its GPTQ quantized version.
-
-        Applies the GPTQ quantization on weights based on the collected statistics.
-
-        Parameters:
-            model: The prepared PyTorch model.
-
-        Returns:
-            The quantized model.
-        """
-        gptq_conf = self.config
-        assert isinstance(gptq_conf, GPTQConfig)
-
-        # Save the original cache setting and disable caching during calibration/inference.
-        if hasattr(model, "config"):
-            use_cache = model.config.use_cache
-            model.config.use_cache = False
-
-        quantizers = {}
-        if hasattr(model, "model"):
-            target_layers = model.model.layers
-        else:
-            target_layers = [model]
-        for l_idx, layer in enumerate(target_layers):
-            # Identify quantizable submodules within the layer.
-            full = find_layers(layer)
-
-            sequential = [list(full.keys())]
-            for names in sequential:
-                subset = {n: full[n] for n in names}
-
-                gptq: Dict[str, GPTQ] = {}
-                for name in subset:
-                    gptq[name] = GPTQ(subset[name])
-                    gptq[name].quantizer.configure(
-                        8, perchannel=True, sym=False, mse=False
-                    )
-                # Define a hook to collect input/output batches for quantizer calibration.
-                def add_batch(name):
-                    def tmp(_, inp, out):
-                        gptq[name].add_batch(inp[0].data, out.data)
-
-                    return tmp
-
-                handles = []
-                for name in subset:
-                    handles.append(subset[name].register_forward_hook(add_batch(name)))
-                # Run the current layer on the stored calibration inputs to capture activation stats.
-                batch_num = self.cache_kwargs.pop("batch_num")
-                for batch_idx in range(batch_num):
-                    cache_args_batch = gather_single_batch_from_list(
-                        self.cache_args, batch_idx
-                    )
-                    cache_kwargs_batch = gather_single_batch_from_dict(
-                        self.cache_kwargs, batch_idx
-                    )
-                    layer(*cache_args_batch, **cache_kwargs_batch)[0]
-                self.cache_kwargs["batch_num"] = batch_num
-                for h in handles:
-                    h.remove()
-                # Quantize each submodule using the collected calibration data.
-                for name in subset:
-                    if gptq_conf.verbose:
-                        print(l_idx, name)
-                        print("Quantizing ...")
-                    gptq[name].fasterquant(
-                        percdamp=0.01,
-                        groupsize=-1,
-                        actorder=True,
-                        static_groups=False,
-                        verbose=gptq_conf.verbose,
-                    )
-                    quantizers["model.layers.%d.%s" % (l_idx, name)] = gptq[
-                        name
-                    ].quantizer
-                    gptq[name].free()
-                """
-                Execute the quantized layer with the calibration inputs to obtain ouptuts
-                that will serve as inputs for the next layer.
-
-                This ensures that the quantization effects are correctly propagated to subsequent
-                layers.
-                """
-                batch_num = self.cache_kwargs.pop("batch_num")
-                for batch_idx in range(batch_num):
-                    cache_args_batch = gather_single_batch_from_list(
-                        self.cache_args, batch_idx
-                    )
-                    cache_kwargs_batch = gather_single_batch_from_dict(
-                        self.cache_kwargs, batch_idx
-                    )
-                    outs = layer(*cache_args_batch, **cache_kwargs_batch)[0]
-                    # Update inputs for next iteration.
-                    self.cache_args[0][batch_idx] = outs
-                self.cache_kwargs["batch_num"] = batch_num
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-        # Restore the original cache configuration.
-        if hasattr(model, "config"):
-            model.config.use_cache = use_cache
-
-        return model
--- a/tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py
+++ /dev/null
@@ -1,164 +0,0 @@
-# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any, Dict, List, Optional
-
-import torch
-
-
-@torch.no_grad()
-def smooth_weights(
-    front_module: torch.nn.Module,
-    back_modules: torch.nn.Module | List[torch.nn.Module],
-    activation_max: torch.Tensor,
-    alpha: float,
-):
-    """
-    Applies SmoothQuant-style smoothing to the weights and biases of two
-    connected modules using activation maximum values.
-
-    NOTE All modules **MUST** have `weight` and optionally `bias` attributes.
-
-    Parameters
-    -----------
-    front_module
-        The front module whose weights and biases will be adjusted.
-    back_modules
-        A list of back modules whose weights and biases will be adjusted.
-    activation_max
-        A tensor of channel-wise maximum activation values for the front module.
-    alpha
-        The smoothing factor that determines the scaling for weight adjustments.
-
-    Raises
-    -------
-    AttributeError
-        If `front_module` or any module in `back_modules` does not have `weight` attributes.
-    ValueError
-        If the shape of tensors in `activation_max` does not match the number of channels
-        in `front_module`'s weight.
-    NoteImplementedError
-        If `front_module` or any module in `back_modules` is of an unsupported type.
-    """
-    from transformers.models.llama.modeling_llama import LlamaRMSNorm
-
-    if not isinstance(back_modules, list):
-        back_modules = [back_modules]
-
-    # Check attributes
-    if not hasattr(front_module, "weight"):
-        raise AttributeError(
-            f"The front module '{type(front_module).__name__}' does not have a 'weight' attribute."
-        )
-    for back_m in back_modules:
-        if not hasattr(back_m, "weight"):
-            raise AttributeError(
-                f"The front module '{type(back_m).__name__}' does not have a 'weight' attribute."
-            )
-    # Check shapes
-    if isinstance(front_module, LlamaRMSNorm):
-        front_numel = front_module.weight.numel()
-    else:
-        raise NotImplementedError(
-            f"Unsupported module type: {type(front_module).__name__}"
-        )
-    for back_m in back_modules:
-        if isinstance(back_m, torch.nn.Linear):
-            back_numel = back_m.in_features
-        else:
-            raise NotImplementedError(
-                f"Unsupported module type: {type(front_module).__name__}"
-            )
-
-    if front_numel != back_numel or back_numel != activation_max.numel():
-        raise ValueError(
-            f"Shape mismatch: front_numel({front_numel}), back_numel({back_numel}), activation_max_numel({activation_max.numel()})"
-        )
-
-    # Compute scales
-    device, dtype = back_modules[0].weight.device, back_modules[0].weight.dtype
-    activation_max = activation_max.to(device=device, dtype=dtype)  # type: ignore[arg-type]
-    weight_scales = torch.cat(
-        [back_m.weight.abs().max(dim=0, keepdim=True)[0] for back_m in back_modules],  # type: ignore[operator]
-        dim=0,
-    )
-    weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
-    scales = (
-        (activation_max.pow(alpha) / weight_scales.pow(1 - alpha))
-        .clamp(min=1e-5)
-        .to(device)  # type: ignore[arg-type]
-        .to(dtype)  # type: ignore[arg-type]
-    )
-
-    # Smooth
-    front_module.weight.div_(scales)
-    if hasattr(front_module, "bias"):
-        front_module.bias.div_(scales)
-
-    for back_m in back_modules:
-        back_m.weight.mul_(scales.view(1, -1))  # type: ignore[operator]
-
-
-@torch.no_grad()
-def apply_smoothing(
-    model: torch.nn.Module,
-    activation_max: Dict[str, torch.Tensor],
-    alpha: float = 0.5,
-    custom_alpha_map: Optional[Dict[str, float]] = None,
-):
-    """
-    Applies SmoothQuant-style smoothing to the model's weights using activation maximum values.
-
-    Parameters
-    -----------
-    model
-        A torch module whose weights will be smoothed.
-    activation_max
-        The channel-wise maximum activation values for the model.
-    alpha
-        The default smoothing factor to apply across all modules.
-    custom_alpha_map
-        A dictionary mapping layer/module names to custom alpha values.
-        Layers specified in this dictionary will use the corresponding alpha
-        value instead of the default.
-    """
-    from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-
-    for name, module in model.named_modules():
-        alpha_to_apply = alpha
-        if custom_alpha_map and name in custom_alpha_map:
-            alpha_to_apply = custom_alpha_map[name]
-        if alpha_to_apply > 1.0:
-            raise RuntimeError(
-                f"Alpha value cannot exceed 1.0. Given alpha: {alpha_to_apply}"
-            )
-        # SmoothQuant is applied before capturing the graph. Therefore, it needs to know
-        # specific module information.
-        # TODO Suport more modules.
-        if isinstance(module, LlamaDecoderLayer):
-            attn_ln = module.input_layernorm
-            qkv = [
-                module.self_attn.q_proj,
-                module.self_attn.k_proj,
-                module.self_attn.v_proj,
-            ]
-
-            qkv_input_scales = activation_max[name + ".self_attn.q_proj"]
-            smooth_weights(attn_ln, qkv, qkv_input_scales, alpha_to_apply)
-
-            ffn_ln = module.post_attention_layernorm
-            fcs = [module.mlp.gate_proj, module.mlp.up_proj]
-            fcs_input_scales = activation_max[name + ".mlp.gate_proj"]
-
-            smooth_weights(ffn_ln, fcs, fcs_input_scales, alpha_to_apply)
--- a/tico/experimental/quantization/evaluation/metric.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any, Callable, Dict, List
-
-import numpy as np
-import torch
-
-
-def compute_peir(base: torch.Tensor, target: torch.Tensor):
-    """
-    Calculate the Peak Error to Interval Ratio (PEIR) between two tensors.
-
-    This function computes the PEIR between two tensors using the formula:
-    PEIR = max(abs(tensor1 - tensor2)) / (max(tensor1) - min(tensor2))
-    """
-    assert base.shape == target.shape, f"shape mismatch: {base.shape} != {target.shape}"
-    base_tensor = base.numpy()
-    target_tensor = target.numpy()
-    assert (
-        base_tensor.dtype == np.float32 and target_tensor.dtype == np.float32
-    ), f"dtype should be float32: base({base_tensor.dtype}), target({target_tensor.dtype})"
-
-    base_tensor = base_tensor.reshape(-1)
-    target_tensor = target_tensor.reshape(-1)
-
-    assert (
-        base_tensor.shape == target_tensor.shape
-    ), f"Shape mismatch: {base_tensor.shape} != {target_tensor.shape}"
-
-    peak_error = np.max(np.absolute(target_tensor - base_tensor))
-    interval = np.max(base_tensor) - np.min(base_tensor)
-    peir = peak_error / interval  # pylint: disable=invalid-name
-
-    min_value = min([base_tensor.min(), target_tensor.min()])
-    max_value = max([base_tensor.max(), target_tensor.max()])
-
-    interval = max_value - min_value
-    interval = 1.0 if interval == 0.0 else interval  # Avoid zero interval
-
-    return peir
-
-
-class MetricCalculator:
-    """
-    Compute metrics including both built-in and custom metrics.
-
-    metrics
-        A list of metric names for comparison.
-    custom_metrics
-        A dictionary of metric names and corresponding callable functions for comparison.
-        Example: {'mse': mean_squared_error, 'cosine_similarity': cosine_similarity_fn}
-    """
-
-    builtin_metrics = {
-        "peir": compute_peir,
-    }
-
-    def __init__(
-        self,
-        metrics: List[str] = list(),
-        custom_metrics: Dict[str, Callable] = dict(),
-    ):
-        self.metrics: Dict[str, Callable] = dict()
-
-        for m in metrics:
-            if m in self.builtin_metrics:
-                self.metrics[m] = self.builtin_metrics[m]
-            else:
-                raise RuntimeError(f"Invalid metric: {m}")
-
-        duplicates = set(self.metrics).intersection(custom_metrics.keys())
-        if len(duplicates) != 0:
-            raise RuntimeError(f"There are duplicate metrics: {duplicates}")
-
-        self.metrics = self.metrics | custom_metrics
-
-    def compute(
-        self, output1: List[torch.Tensor], output2: List[torch.Tensor]
-    ) -> Dict[str, List[Any]]:
-        """
-        Compute both built-in metrics (if provided) and custom metrics.
-
-        Returns
-        --------
-        Dict[str, Any]
-            A dictionary with metric names and their computed values.
-        """
-        results: Dict[str, List[Any]] = dict()
-
-        # Compute built-in metrics
-        if self.metrics is not None:
-            for m in self.metrics:
-                results[m] = list()
-                for out1, out2 in zip(output1, output2):
-                    results[m].append(self.builtin_metrics[m](out1, out2))
-
-        return results