tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +9 -1
- tico/config/base.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_aten_where_arg_type.py +1 -1
- tico/passes/cast_clamp_mixed_type_args.py +169 -0
- tico/passes/cast_mixed_type_args.py +4 -2
- tico/passes/const_prop_pass.py +1 -1
- tico/passes/convert_conv1d_to_conv2d.py +1 -1
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_addmm.py +0 -3
- tico/passes/decompose_batch_norm.py +2 -2
- tico/passes/decompose_fake_quantize.py +0 -3
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
- tico/passes/decompose_group_norm.py +0 -3
- tico/passes/legalize_predefined_layout_operators.py +2 -11
- tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
- tico/passes/lower_to_slice.py +1 -1
- tico/passes/merge_consecutive_cat.py +1 -1
- tico/passes/ops.py +1 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -6
- tico/passes/remove_redundant_reshape.py +5 -5
- tico/passes/segment_index_select.py +1 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
- tico/quantization/algorithm/gptq/quantizer.py +292 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
- tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
- tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
- tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +19 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +125 -0
- tico/serialize/circle_graph.py +12 -4
- tico/serialize/circle_mapping.py +76 -2
- tico/serialize/circle_serializer.py +253 -148
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_any.py +7 -14
- tico/serialize/operators/op_avg_pool2d.py +11 -4
- tico/serialize/operators/op_clamp.py +5 -7
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_conv2d.py +14 -6
- tico/serialize/operators/op_copy.py +26 -3
- tico/serialize/operators/op_cumsum.py +3 -1
- tico/serialize/operators/op_depthwise_conv2d.py +17 -7
- tico/serialize/operators/op_full_like.py +0 -2
- tico/serialize/operators/op_index_select.py +8 -1
- tico/serialize/operators/op_instance_norm.py +0 -6
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_log1p.py +3 -2
- tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
- tico/serialize/operators/op_mm.py +15 -131
- tico/serialize/operators/op_mul.py +2 -8
- tico/serialize/operators/op_pow.py +3 -1
- tico/serialize/operators/op_repeat.py +12 -3
- tico/serialize/operators/op_reshape.py +1 -1
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/serialize/operators/op_softmax.py +7 -14
- tico/serialize/operators/op_split_with_sizes.py +16 -8
- tico/serialize/operators/op_transpose_conv.py +11 -8
- tico/serialize/operators/op_view.py +2 -1
- tico/serialize/quant_param.py +5 -5
- tico/utils/convert.py +30 -17
- tico/utils/dtype.py +42 -0
- tico/utils/graph.py +1 -1
- tico/utils/model.py +2 -1
- tico/utils/padding.py +2 -2
- tico/utils/pytree_utils.py +134 -0
- tico/utils/record_input.py +102 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/serialize.py +16 -3
- tico/utils/signature.py +247 -0
- tico/utils/torch_compat.py +52 -0
- tico/utils/utils.py +50 -58
- tico/utils/validate_args_kwargs.py +38 -3
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
- tico-0.1.0.dev251102.dist-info/RECORD +271 -0
- tico/experimental/quantization/__init__.py +0 -1
- tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
- tico-0.1.0.dev250714.dist-info/RECORD +0 -209
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
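
The headline change in this release is that the quantization stack moves out of tico.experimental.quantization into a stable tico.quantization package, with new config/, wrapq/, and quantizer_registry modules. As a rough orientation, here is a minimal sketch of the new import layout, using only paths that appear in the diffs below (the same pieces previously lived under tico.experimental.quantization):

    # New package layout (paths taken from the diffs below); the old locations
    # lived under tico.experimental.quantization.
    from tico.quantization.config.gptq import GPTQConfig
    from tico.quantization.quantizer import BaseQuantizer
    from tico.quantization.quantizer_registry import register_quantizer
    from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig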

tico/quantization/algorithm/gptq/quantizer.py ADDED

@@ -0,0 +1,292 @@
+# Copyright (c) 2024 Intel Corporation
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import types
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+from tqdm.auto import tqdm
+
+from tico.quantization.algorithm.gptq.gptq import GPTQ
+from tico.quantization.algorithm.gptq.utils import (
+    find_layers,
+    gather_single_batch_from_dict,
+    gather_single_batch_from_list,
+)
+from tico.quantization.config.gptq import GPTQConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import register_quantizer
+
+
+class StopForward(Exception):
+    """Custom exception used to stop the forward pass after the first layer."""
+
+    pass
+
+
+@register_quantizer(GPTQConfig)
+class GPTQQuantizer(BaseQuantizer):
+    """
+    Quantizer for applying the GPTQ algorithm (typically for weight quantization).
+    This implementation expects:
+    1) prepare(model, ...) to only attach hooks/Catchers and NOT run the model internally.
+    2) The user runs the model with arbitrary number of batches to collect calibration data.
+    3) convert(model) to consume the collected data and apply GPTQ.
+    """
+
+    def __init__(self, config: GPTQConfig):
+        super().__init__(config)
+
+        # cache_args[i] -> list of the i-th positional argument for each batch
+        self.cache_args: List[List[Any]] = []
+        # cache_kwargs[k] -> list of the value for keyword k for each batch
+        self.cache_kwargs: Dict[str, List[Any]] = {}
+        self.num_batches: int = 0
+
+        # References to original forwards for restoration
+        self._orig_model_forward: Optional[Callable[..., Any]] = None
+        self._orig_layer_forward: Optional[Callable[..., Any]] = None
+        self._first_layer_ref: Optional[torch.nn.Module] = None
+
+    @torch.no_grad()
+    def prepare(
+        self,
+        model: torch.nn.Module,
+        args: Optional[Any] = None,
+        kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Overrides the forward method of the first LLaMA layer (layer 0) to capture the
+        input required for calibration.
+
+        When the user calls `model(...)`, we intercept (and store) the inputs to that
+        layer, then raise an exception to stop the forward pass immediately. These
+        captured inputs are then utilized to calibrate the quantization parameters
+        for the GPTQ.
+
+        Parameters:
+            model (torch.nn.Module): The target PyTorch model
+            args (Any, optional): Unused (kept for API compatibility)
+            kwargs (Dict[str, Any], optional): Unused (kept for API compatibility)
+
+        Returns:
+            torch.nn.Module: The model with the catcher attached
+        """
+        # Define the catcher to store inputs/kwargs and stop the execution
+        def forward(layer, *args, **kwargs):
+            """
+            Stores this batch's inputs and kwargs, then raises StopForward to stop computation.
+            """
+            # Store positional args
+            for idx, item in enumerate(args):
+                if (idx + 1) > len(self.cache_args):
+                    self.cache_args.append([])
+                self.cache_args[idx].append(item)
+            # Store keyword args
+            for k, v in kwargs.items():
+                if self.cache_kwargs.get(k, None) is None:
+                    self.cache_kwargs[k] = []
+                self.cache_kwargs[k].append(v)
+
+            self.num_batches += 1
+            raise StopForward  # stop after the first layer
+
+        # Replace the first layer with defined function to capture calibration data.
+        if hasattr(model, "model"):
+            if hasattr(model.model, "layers") and isinstance(
+                model.model.layers, torch.nn.ModuleList
+            ):
+                self._first_layer_ref = model.model.layers[0]
+            else:
+                raise RuntimeError(
+                    "GPTQ Quantizer assumes the model has a nested structure like `model.model.layers`, commonly found in LLaMA and other Hugging Face transformer models."
+                )
+        else:
+            # fallback if the model is not LLaMA-like; treat whole model as single layer
+            self._first_layer_ref = model
+
+        assert hasattr(self._first_layer_ref, "forward")
+        # Backup the original forward of the first layer
+        assert isinstance(self._first_layer_ref, torch.nn.Module)
+        self._orig_layer_forward = self._first_layer_ref.forward
+        self._first_layer_ref.forward = types.MethodType(forward, self._first_layer_ref)
+
+        def model_forward_wrapper(_model, *m_args, **m_kwargs):
+            """
+            Wrapper to ignore StopForward exceptions so the user's training loop doesn't crash.
+            """
+            try:
+                assert self._orig_model_forward is not None
+                return self._orig_model_forward(*m_args, **m_kwargs)
+            except StopForward:
+                # We stopped after the first layer; return None or dummy output if needed.
+                return None
+
+        # Backup model.forward so we can suppress StopForward
+        self._orig_model_forward = model.forward
+        model.forward = types.MethodType(model_forward_wrapper, model)
+
+        return model
+
+    @torch.no_grad()
+    def convert(self, model):
+        """
+        Perform GPTQ quantization using cached first-layer inputs.
+
+        Steps:
+        1) Restore original forwards (no more catching).
+        2) Iterate through each Transformer layer sequentially:
+            a) For each layer, register forward hooks to collect (inp, out) stats for GPTQ.
+            b) Run the layer on cached inputs for all batches.
+            c) Apply GPTQ and update the weights.
+            d) Re-run the layer to produce outputs for the next layer; update cached inputs.
+        3) Restore model.config.use_cache if needed and clear internal caches.
+
+        Parameters:
+            model (torch.nn.Module): The prepared model.
+
+        Returns:
+            torch.nn.Module: Quantized model.
+        """
+        # Restore original forwards (we no longer want to stop after first layer)
+        assert self._orig_model_forward is not None
+        model.forward = self._orig_model_forward
+        assert (
+            self._first_layer_ref is not None and self._orig_layer_forward is not None
+        )
+        self._first_layer_ref.forward = self._orig_layer_forward
+
+        gptq_conf = self.config
+        assert isinstance(gptq_conf, GPTQConfig)
+        # Disable use_cache during calibration
+        if hasattr(model, "config") and hasattr(model.config, "use_cache"):
+            orig_use_cache = model.config.use_cache
+            model.config.use_cache = False
+        else:
+            orig_use_cache = None
+
+        # Identify layers
+        if hasattr(model, "model"):
+            target_layers = model.model.layers
+        else:
+            target_layers = [model]
+
+        quantizers: Dict[str, Any] = {}
+        for l_idx, layer in enumerate(
+            tqdm(
+                target_layers,
+                desc="Quantizing layers",
+                unit="layer",
+                disable=not gptq_conf.show_progress,
+            )
+        ):
+            # 1) Identify quantizable submodules within the layer
+            full = find_layers(layer)
+            sequential = [list(full.keys())]
+
+            # 2) Set up GPTQ objects and gather stats
+            for names in sequential:
+                subset = {n: full[n] for n in names}
+
+                gptq: Dict[str, GPTQ] = {}
+                for name in subset:
+                    gptq[name] = GPTQ(subset[name])
+                    gptq[name].quantizer.configure(
+                        bits=8, perchannel=True, sym=False, mse=False
+                    )
+
+                # Hook to collect (inp, out) for GPTQ
+                def add_batch(name):
+                    def _hook(_, inp, out):
+                        gptq[name].add_batch(inp[0].data, out.data)
+
+                    return _hook
+
+                handles = []
+                for name in subset:
+                    handles.append(subset[name].register_forward_hook(add_batch(name)))
+
+                # Run layer forward over all cached batches to build Hessian/statistics
+                batch_num = self.num_batches
+                for batch_idx in tqdm(
+                    range(batch_num),
+                    desc=f"[L{l_idx}] collecting",
+                    leave=False,
+                    unit="batch",
+                    disable=not gptq_conf.show_progress,
+                ):
+                    cache_args_batch = gather_single_batch_from_list(
+                        self.cache_args, batch_idx
+                    )
+                    cache_kwargs_batch = gather_single_batch_from_dict(
+                        self.cache_kwargs, batch_idx
+                    )
+                    layer(*cache_args_batch, **cache_kwargs_batch)
+
+                # Remove handles
+                for h in handles:
+                    h.remove()
+
+                # 3) Quantize each submodule
+                for name in subset:
+                    if gptq_conf.verbose:
+                        print(f"[Layer {l_idx}] {name} -> Quantizing ...")
+                    gptq[name].fasterquant(
+                        percdamp=0.01,
+                        groupsize=-1,
+                        actorder=True,
+                        static_groups=False,
+                        verbose=gptq_conf.verbose,
+                    )
+                    quantizers[f"model.layers.{l_idx}.{name}"] = gptq[name].quantizer
+                    gptq[name].free()
+
+            # 4) After quantization, re-run the layer to produce outputs for the next layer
+            for batch_idx in tqdm(
+                range(batch_num),
+                desc=f"[L{l_idx}] re-forward",
+                leave=False,
+                unit="batch",
+                disable=not gptq_conf.show_progress,
+            ):
+                cache_args_batch = gather_single_batch_from_list(
+                    self.cache_args, batch_idx
+                )
+                cache_kwargs_batch = gather_single_batch_from_dict(
+                    self.cache_kwargs, batch_idx
+                )
+                outs = layer(*cache_args_batch, **cache_kwargs_batch)
+                # LLaMA's decoder layer return type differs across Transformers versions:
+                # some return a tuple (hidden_states, ...), others return just a tensor.
+                # This line ensures we always take the first element when it's a tuple.
+                outs = outs[0] if isinstance(outs, tuple) else outs
+                # Update inputs for next iteration.
+                self.cache_args[0][batch_idx] = outs
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        # Restore the original cache configuration.
+        if orig_use_cache is not None:
+            model.config.use_cache = orig_use_cache
+
+        # Clear caches to free memory
+        self.cache_args.clear()
+        self.cache_kwargs.clear()
+        self.num_batches = 0
+
+        model.quantizers = quantizers
+
+        return model
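
The docstrings above describe a three-step workflow: attach the catcher, run calibration batches, then convert. A minimal usage sketch under those assumptions follows; load_llama and calibration_loader are placeholders for the user's own model and data, and GPTQConfig() is assumed to be constructible with defaults.

    import torch

    from tico.quantization.algorithm.gptq.quantizer import GPTQQuantizer
    from tico.quantization.config.gptq import GPTQConfig

    model = load_llama()                    # placeholder: any model with model.model.layers
    quantizer = GPTQQuantizer(GPTQConfig())

    # 1) Attach the first-layer catcher; no forward pass happens here.
    quantizer.prepare(model)

    # 2) Run calibration batches; each call is intercepted after layer 0.
    with torch.no_grad():
        for batch in calibration_loader:    # placeholder iterable of input tensors
            model(batch)

    # 3) Consume the cached inputs and quantize layer by layer.
    model = quantizer.convert(model)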

tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py RENAMED

@@ -58,7 +58,7 @@ def gather_single_batch_from_list(data_list, idx):
     Returns:
         list: single batch.
     """
-    # obtain a set of
+    # obtain a set of positional input from cache
     single_batch = []
     for data_item in data_list:
         single_batch.append(data_item[idx])
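
For context, the caches this helper indexes are laid out one list per positional argument (see cache_args in the quantizer above), so gathering a batch means taking the same index from every per-argument list. A toy illustration, not tico code:

    # cache_args[i][b] holds the i-th positional argument captured for batch b.
    cache_args = [
        ["hidden_b0", "hidden_b1"],        # positional arg 0 for batches 0 and 1
        ["attn_mask_b0", "attn_mask_b1"],  # positional arg 1 for batches 0 and 1
    ]

    def gather_single_batch_from_list(data_list, idx):
        # mirrors the helper shown above: pick entry idx from each per-argument list
        return [data_item[idx] for data_item in data_list]

    print(gather_single_batch_from_list(cache_args, 1))  # ['hidden_b1', 'attn_mask_b1']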

tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py RENAMED

@@ -21,23 +21,16 @@ if TYPE_CHECKING:
     import torch.fx
     from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
 import torch
-from torch.ao.quantization.observer import (
-    MinMaxObserver,
-    MovingAverageMinMaxObserver,
-    MovingAveragePerChannelMinMaxObserver,
-    PerChannelMinMaxObserver,
-)
+from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
 from torch.ao.quantization.quantizer.utils import _get_module_name_filter
 
-from tico.
-import tico.
-import tico.
-import tico.
-from tico.
-
-)
-from tico.experimental.quantization.algorithm.pt2e.transformation.convert_scalars_to_attrs import (
+from tico.quantization.algorithm.pt2e.annotation.op import *
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
+from tico.quantization.algorithm.pt2e.transformation.convert_scalars_to_attrs import (
     convert_scalars_to_attrs,
 )
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py RENAMED

@@ -19,12 +19,10 @@ if TYPE_CHECKING:
 import torch
 from torch.ao.quantization.quantizer import SharedQuantizationSpec
 
-import tico.
-import tico.
-import tico.
-from tico.
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import AdaptiveAvgPool2dArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py RENAMED

@@ -18,12 +18,10 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-import tico.
-import tico.
-import tico.
-from tico.
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import AddTensorArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py RENAMED

@@ -19,12 +19,10 @@ if TYPE_CHECKING:
 import torch
 from torch.ao.quantization.quantizer import DerivedQuantizationSpec
 
-import tico.
-import tico.
-import tico.
-from tico.
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import Conv2DArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py RENAMED

@@ -18,12 +18,10 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-import tico.
-import tico.
-import tico.
-from tico.
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import DivTensorArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py RENAMED

@@ -12,19 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable,
+from typing import Callable, Optional, TYPE_CHECKING
 
 if TYPE_CHECKING:
     import torch.fx
 import torch
 from torch.ao.quantization.quantizer import DerivedQuantizationSpec
 
-import tico.
-import tico.
-import tico.
-from tico.
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import LinearArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py RENAMED

@@ -18,12 +18,10 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-import tico.
-import tico.
-import tico.
-from tico.
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import MeanDimArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py RENAMED

@@ -18,12 +18,10 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-import tico.
-import tico.
-import tico.
-from tico.
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import MulTensorArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py RENAMED

@@ -18,12 +18,10 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-import tico.
-import tico.
-import tico.
-from tico.
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import Relu6Args
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py RENAMED

@@ -18,12 +18,10 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-import tico.
-import tico.
-import tico.
-from tico.
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import RsqrtArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py RENAMED

@@ -18,12 +18,10 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-import tico.
-import tico.
-import tico.
-from tico.
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import SubTensorArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py RENAMED

@@ -18,9 +18,7 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-from tico.
-    QuantizationConfig,
-)
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 
 AnnotatorType = Callable[
     [

tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py RENAMED

@@ -22,7 +22,7 @@ from torch.ao.quantization.quantizer import (
     SharedQuantizationSpec,
 )
 
-import tico.
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
 
 
 def annotate_input_qspec_map(node: torch.fx.Node, input_node: torch.fx.Node, qspec):

tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py RENAMED

@@ -18,13 +18,16 @@ import torch
 
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
-from tico.
+from tico.quantization.algorithm.pt2e.annotation.annotator import (
     get_asymmetric_quantization_config,
     PT2EAnnotator,
 )
-from tico.
+from tico.quantization.config.pt2e import PT2EConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import register_quantizer
 
 
+@register_quantizer(PT2EConfig)
 class PT2EQuantizer(BaseQuantizer):
     """
     Quantizer for applying pytorch 2.0 export quantization (typically for activation quantization).
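
PT2EQuantizer is now registered against its config class via @register_quantizer(PT2EConfig), mirroring GPTQQuantizer above. The registry module itself (tico/quantization/quantizer_registry.py, +73 lines) is not shown in this diff, so the following is only a generic sketch of the decorator pattern the usage implies, not tico's actual implementation:

    from typing import Dict

    _REGISTRY: Dict[type, type] = {}

    def register_quantizer(config_cls: type):
        """Map a config class to the quantizer class that handles it."""
        def _decorator(quantizer_cls: type):
            _REGISTRY[config_cls] = quantizer_cls
            return quantizer_cls
        return _decorator

    # A front end could then pick the quantizer from the config instance it was given:
    #   quantizer_cls = _REGISTRY[type(config)]
    #   quantizer = quantizer_cls(config)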
tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py RENAMED

@@ -19,11 +19,8 @@ if TYPE_CHECKING:
 import torch
 from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.utils import _get_module_name_filter
-from torch.utils import _pytree as pytree
 
-from tico.
-    QuantizationConfig,
-)
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 
 
 def get_module_type_filter(tp: Callable):

tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py RENAMED

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import functools
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Literal
 
 import torch
 
@@ -21,18 +21,24 @@ import torch
 class ChannelwiseMaxActsObserver:
     """
     Observer to calcuate channelwise maximum activation
+    It supports collecting activations from either module inputs or outputs.
     """
 
-    def __init__(
+    def __init__(
+        self, model: torch.nn.Module, acts_from: Literal["input", "output"] = "input"
+    ):
         """
         model
             A torch module whose activations are to be analyzed.
+        acts_from
+            Where to hook: "input" for forward-pre-hook, "output" for forward-hook.
         hooks
-            A list to store the hooks
+            A list to store the hooks registered to collect activation statistics.
         max_acts
-            A dictionary to store the
+            A dictionary to store the per-channel maxima.
         """
         self.model = model
+        self.acts_from: Literal["input", "output"] = acts_from
         self.hooks: List[Any] = []
         self.max_acts: Dict[str, torch.Tensor] = {}
 
@@ -62,13 +68,25 @@ class ChannelwiseMaxActsObserver:
                 input = input[0]
             stat_tensor(name, input)
 
+        def stat_output_hook(m, input, output, name):
+            if isinstance(output, tuple):
+                output = output[0]
+            stat_tensor(name, output)
+
         for name, m in self.model.named_modules():
             if isinstance(m, torch.nn.Linear):
-                self.
-
-
+                if self.acts_from == "input":
+                    self.hooks.append(
+                        m.register_forward_pre_hook(
+                            functools.partial(stat_input_hook, name=name)
+                        )
+                    )
+                else: # "output"
+                    self.hooks.append(
+                        m.register_forward_hook(
+                            functools.partial(stat_output_hook, name=name)
+                        )
                     )
-                )
 
     def remove(self):
         for hook in self.hooks: