tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251106__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +60 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +128 -0
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/utils/convert.py +20 -15
- tico/utils/dtype.py +22 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/signature.py +247 -0
- tico/utils/utils.py +50 -53
- tico/utils/validate_args_kwargs.py +37 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/METADATA +49 -2
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/RECORD +130 -73
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/top_level.txt +0 -0
|
@@ -17,15 +17,17 @@ import types
|
|
|
17
17
|
from typing import Any, Callable, Dict, List, Optional
|
|
18
18
|
|
|
19
19
|
import torch
|
|
20
|
+
from tqdm.auto import tqdm
|
|
20
21
|
|
|
21
|
-
from tico.
|
|
22
|
-
from tico.
|
|
22
|
+
from tico.quantization.algorithm.gptq.gptq import GPTQ
|
|
23
|
+
from tico.quantization.algorithm.gptq.utils import (
|
|
23
24
|
find_layers,
|
|
24
25
|
gather_single_batch_from_dict,
|
|
25
26
|
gather_single_batch_from_list,
|
|
26
27
|
)
|
|
27
|
-
from tico.
|
|
28
|
-
from tico.
|
|
28
|
+
from tico.quantization.config.gptq import GPTQConfig
|
|
29
|
+
from tico.quantization.quantizer import BaseQuantizer
|
|
30
|
+
from tico.quantization.quantizer_registry import register_quantizer
|
|
29
31
|
|
|
30
32
|
|
|
31
33
|
class StopForward(Exception):
|
|
@@ -34,6 +36,7 @@ class StopForward(Exception):
|
|
|
34
36
|
pass
|
|
35
37
|
|
|
36
38
|
|
|
39
|
+
@register_quantizer(GPTQConfig)
|
|
37
40
|
class GPTQQuantizer(BaseQuantizer):
|
|
38
41
|
"""
|
|
39
42
|
Quantizer for applying the GPTQ algorithm (typically for weight quantization).
|
|
@@ -43,7 +46,7 @@ class GPTQQuantizer(BaseQuantizer):
|
|
|
43
46
|
3) convert(model) to consume the collected data and apply GPTQ.
|
|
44
47
|
"""
|
|
45
48
|
|
|
46
|
-
def __init__(self, config:
|
|
49
|
+
def __init__(self, config: GPTQConfig):
|
|
47
50
|
super().__init__(config)
|
|
48
51
|
|
|
49
52
|
# cache_args[i] -> list of the i-th positional argument for each batch
|
|
@@ -181,7 +184,14 @@ class GPTQQuantizer(BaseQuantizer):
|
|
|
181
184
|
target_layers = [model]
|
|
182
185
|
|
|
183
186
|
quantizers: Dict[str, Any] = {}
|
|
184
|
-
for l_idx, layer in enumerate(
|
|
187
|
+
for l_idx, layer in enumerate(
|
|
188
|
+
tqdm(
|
|
189
|
+
target_layers,
|
|
190
|
+
desc="Quantizing layers",
|
|
191
|
+
unit="layer",
|
|
192
|
+
disable=not gptq_conf.show_progress,
|
|
193
|
+
)
|
|
194
|
+
):
|
|
185
195
|
# 1) Identify quantizable submodules within the layer
|
|
186
196
|
full = find_layers(layer)
|
|
187
197
|
sequential = [list(full.keys())]
|
|
@@ -210,7 +220,13 @@ class GPTQQuantizer(BaseQuantizer):
|
|
|
210
220
|
|
|
211
221
|
# Run layer forward over all cached batches to build Hessian/statistics
|
|
212
222
|
batch_num = self.num_batches
|
|
213
|
-
for batch_idx in
|
|
223
|
+
for batch_idx in tqdm(
|
|
224
|
+
range(batch_num),
|
|
225
|
+
desc=f"[L{l_idx}] collecting",
|
|
226
|
+
leave=False,
|
|
227
|
+
unit="batch",
|
|
228
|
+
disable=not gptq_conf.show_progress,
|
|
229
|
+
):
|
|
214
230
|
cache_args_batch = gather_single_batch_from_list(
|
|
215
231
|
self.cache_args, batch_idx
|
|
216
232
|
)
|
|
@@ -238,7 +254,13 @@ class GPTQQuantizer(BaseQuantizer):
|
|
|
238
254
|
gptq[name].free()
|
|
239
255
|
|
|
240
256
|
# 4) After quantization, re-run the layer to produce outputs for the next layer
|
|
241
|
-
for batch_idx in
|
|
257
|
+
for batch_idx in tqdm(
|
|
258
|
+
range(batch_num),
|
|
259
|
+
desc=f"[L{l_idx}] re-forward",
|
|
260
|
+
leave=False,
|
|
261
|
+
unit="batch",
|
|
262
|
+
disable=not gptq_conf.show_progress,
|
|
263
|
+
):
|
|
242
264
|
cache_args_batch = gather_single_batch_from_list(
|
|
243
265
|
self.cache_args, batch_idx
|
|
244
266
|
)
|
|
@@ -25,14 +25,12 @@ from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObser
|
|
|
25
25
|
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
|
|
26
26
|
from torch.ao.quantization.quantizer.utils import _get_module_name_filter
|
|
27
27
|
|
|
28
|
-
from tico.
|
|
29
|
-
import tico.
|
|
30
|
-
import tico.
|
|
31
|
-
import tico.
|
|
32
|
-
from tico.
|
|
33
|
-
|
|
34
|
-
)
|
|
35
|
-
from tico.experimental.quantization.algorithm.pt2e.transformation.convert_scalars_to_attrs import (
|
|
28
|
+
from tico.quantization.algorithm.pt2e.annotation.op import *
|
|
29
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
30
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
31
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
32
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
33
|
+
from tico.quantization.algorithm.pt2e.transformation.convert_scalars_to_attrs import (
|
|
36
34
|
convert_scalars_to_attrs,
|
|
37
35
|
)
|
|
38
36
|
|
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py
RENAMED
|
@@ -19,12 +19,10 @@ if TYPE_CHECKING:
|
|
|
19
19
|
import torch
|
|
20
20
|
from torch.ao.quantization.quantizer import SharedQuantizationSpec
|
|
21
21
|
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
import tico.
|
|
25
|
-
from tico.
|
|
26
|
-
QuantizationConfig,
|
|
27
|
-
)
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
23
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
24
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
25
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
28
26
|
from tico.utils.validate_args_kwargs import AdaptiveAvgPool2dArgs
|
|
29
27
|
|
|
30
28
|
|
|
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
import tico.
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
from tico.
|
|
25
|
-
QuantizationConfig,
|
|
26
|
-
)
|
|
21
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
23
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
24
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
27
25
|
from tico.utils.validate_args_kwargs import AddTensorArgs
|
|
28
26
|
|
|
29
27
|
|
|
@@ -19,12 +19,10 @@ if TYPE_CHECKING:
|
|
|
19
19
|
import torch
|
|
20
20
|
from torch.ao.quantization.quantizer import DerivedQuantizationSpec
|
|
21
21
|
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
import tico.
|
|
25
|
-
from tico.
|
|
26
|
-
QuantizationConfig,
|
|
27
|
-
)
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
23
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
24
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
25
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
28
26
|
from tico.utils.validate_args_kwargs import Conv2DArgs
|
|
29
27
|
|
|
30
28
|
|
|
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
import tico.
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
from tico.
|
|
25
|
-
QuantizationConfig,
|
|
26
|
-
)
|
|
21
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
23
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
24
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
27
25
|
from tico.utils.validate_args_kwargs import DivTensorArgs
|
|
28
26
|
|
|
29
27
|
|
|
@@ -19,12 +19,10 @@ if TYPE_CHECKING:
|
|
|
19
19
|
import torch
|
|
20
20
|
from torch.ao.quantization.quantizer import DerivedQuantizationSpec
|
|
21
21
|
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
import tico.
|
|
25
|
-
from tico.
|
|
26
|
-
QuantizationConfig,
|
|
27
|
-
)
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
23
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
24
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
25
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
28
26
|
from tico.utils.validate_args_kwargs import LinearArgs
|
|
29
27
|
|
|
30
28
|
|
|
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
import tico.
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
from tico.
|
|
25
|
-
QuantizationConfig,
|
|
26
|
-
)
|
|
21
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
23
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
24
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
27
25
|
from tico.utils.validate_args_kwargs import MeanDimArgs
|
|
28
26
|
|
|
29
27
|
|
|
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
import tico.
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
from tico.
|
|
25
|
-
QuantizationConfig,
|
|
26
|
-
)
|
|
21
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
23
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
24
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
27
25
|
from tico.utils.validate_args_kwargs import MulTensorArgs
|
|
28
26
|
|
|
29
27
|
|
|
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
import tico.
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
from tico.
|
|
25
|
-
QuantizationConfig,
|
|
26
|
-
)
|
|
21
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
23
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
24
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
27
25
|
from tico.utils.validate_args_kwargs import Relu6Args
|
|
28
26
|
|
|
29
27
|
|
|
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
import tico.
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
from tico.
|
|
25
|
-
QuantizationConfig,
|
|
26
|
-
)
|
|
21
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
23
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
24
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
27
25
|
from tico.utils.validate_args_kwargs import RsqrtArgs
|
|
28
26
|
|
|
29
27
|
|
|
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
import tico.
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
from tico.
|
|
25
|
-
QuantizationConfig,
|
|
26
|
-
)
|
|
21
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
23
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
24
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
27
25
|
from tico.utils.validate_args_kwargs import SubTensorArgs
|
|
28
26
|
|
|
29
27
|
|
|
@@ -18,9 +18,7 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
from tico.
|
|
22
|
-
QuantizationConfig,
|
|
23
|
-
)
|
|
21
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
24
22
|
|
|
25
23
|
AnnotatorType = Callable[
|
|
26
24
|
[
|
|
@@ -22,7 +22,7 @@ from torch.ao.quantization.quantizer import (
|
|
|
22
22
|
SharedQuantizationSpec,
|
|
23
23
|
)
|
|
24
24
|
|
|
25
|
-
import tico.
|
|
25
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
def annotate_input_qspec_map(node: torch.fx.Node, input_node: torch.fx.Node, qspec):
|
|
@@ -18,13 +18,16 @@ import torch
|
|
|
18
18
|
|
|
19
19
|
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
|
|
20
20
|
|
|
21
|
-
from tico.
|
|
21
|
+
from tico.quantization.algorithm.pt2e.annotation.annotator import (
|
|
22
22
|
get_asymmetric_quantization_config,
|
|
23
23
|
PT2EAnnotator,
|
|
24
24
|
)
|
|
25
|
-
from tico.
|
|
25
|
+
from tico.quantization.config.pt2e import PT2EConfig
|
|
26
|
+
from tico.quantization.quantizer import BaseQuantizer
|
|
27
|
+
from tico.quantization.quantizer_registry import register_quantizer
|
|
26
28
|
|
|
27
29
|
|
|
30
|
+
@register_quantizer(PT2EConfig)
|
|
28
31
|
class PT2EQuantizer(BaseQuantizer):
|
|
29
32
|
"""
|
|
30
33
|
Quantizer for applying pytorch 2.0 export quantization (typically for activation quantization).
|
|
@@ -20,9 +20,7 @@ import torch
|
|
|
20
20
|
from torch.ao.quantization.quantizer import QuantizationSpec
|
|
21
21
|
from torch.ao.quantization.quantizer.utils import _get_module_name_filter
|
|
22
22
|
|
|
23
|
-
from tico.
|
|
24
|
-
QuantizationConfig,
|
|
25
|
-
)
|
|
23
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
26
24
|
|
|
27
25
|
|
|
28
26
|
def get_module_type_filter(tp: Callable):
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
import functools
|
|
16
|
-
from typing import Any, Dict, List
|
|
16
|
+
from typing import Any, Dict, List, Literal
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
19
|
|
|
@@ -21,18 +21,24 @@ import torch
|
|
|
21
21
|
class ChannelwiseMaxActsObserver:
|
|
22
22
|
"""
|
|
23
23
|
Observer to calcuate channelwise maximum activation
|
|
24
|
+
It supports collecting activations from either module inputs or outputs.
|
|
24
25
|
"""
|
|
25
26
|
|
|
26
|
-
def __init__(
|
|
27
|
+
def __init__(
|
|
28
|
+
self, model: torch.nn.Module, acts_from: Literal["input", "output"] = "input"
|
|
29
|
+
):
|
|
27
30
|
"""
|
|
28
31
|
model
|
|
29
32
|
A torch module whose activations are to be analyzed.
|
|
33
|
+
acts_from
|
|
34
|
+
Where to hook: "input" for forward-pre-hook, "output" for forward-hook.
|
|
30
35
|
hooks
|
|
31
|
-
A list to store the hooks
|
|
36
|
+
A list to store the hooks registered to collect activation statistics.
|
|
32
37
|
max_acts
|
|
33
|
-
A dictionary to store the
|
|
38
|
+
A dictionary to store the per-channel maxima.
|
|
34
39
|
"""
|
|
35
40
|
self.model = model
|
|
41
|
+
self.acts_from: Literal["input", "output"] = acts_from
|
|
36
42
|
self.hooks: List[Any] = []
|
|
37
43
|
self.max_acts: Dict[str, torch.Tensor] = {}
|
|
38
44
|
|
|
@@ -62,13 +68,25 @@ class ChannelwiseMaxActsObserver:
|
|
|
62
68
|
input = input[0]
|
|
63
69
|
stat_tensor(name, input)
|
|
64
70
|
|
|
71
|
+
def stat_output_hook(m, input, output, name):
|
|
72
|
+
if isinstance(output, tuple):
|
|
73
|
+
output = output[0]
|
|
74
|
+
stat_tensor(name, output)
|
|
75
|
+
|
|
65
76
|
for name, m in self.model.named_modules():
|
|
66
77
|
if isinstance(m, torch.nn.Linear):
|
|
67
|
-
self.
|
|
68
|
-
|
|
69
|
-
|
|
78
|
+
if self.acts_from == "input":
|
|
79
|
+
self.hooks.append(
|
|
80
|
+
m.register_forward_pre_hook(
|
|
81
|
+
functools.partial(stat_input_hook, name=name)
|
|
82
|
+
)
|
|
83
|
+
)
|
|
84
|
+
else: # "output"
|
|
85
|
+
self.hooks.append(
|
|
86
|
+
m.register_forward_hook(
|
|
87
|
+
functools.partial(stat_output_hook, name=name)
|
|
88
|
+
)
|
|
70
89
|
)
|
|
71
|
-
)
|
|
72
90
|
|
|
73
91
|
def remove(self):
|
|
74
92
|
for hook in self.hooks:
|
|
@@ -16,20 +16,37 @@ from typing import Any, Dict, Optional
|
|
|
16
16
|
|
|
17
17
|
import torch
|
|
18
18
|
|
|
19
|
-
from tico.
|
|
20
|
-
ChannelwiseMaxActsObserver,
|
|
21
|
-
)
|
|
19
|
+
from tico.quantization.algorithm.smoothquant.observer import ChannelwiseMaxActsObserver
|
|
22
20
|
|
|
23
|
-
from tico.
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
from tico.
|
|
27
|
-
from tico.experimental.quantization.quantizer import BaseQuantizer
|
|
21
|
+
from tico.quantization.algorithm.smoothquant.smooth_quant import apply_smoothing
|
|
22
|
+
from tico.quantization.config.smoothquant import SmoothQuantConfig
|
|
23
|
+
from tico.quantization.quantizer import BaseQuantizer
|
|
24
|
+
from tico.quantization.quantizer_registry import register_quantizer
|
|
28
25
|
|
|
29
26
|
|
|
27
|
+
@register_quantizer(SmoothQuantConfig)
|
|
30
28
|
class SmoothQuantQuantizer(BaseQuantizer):
|
|
31
29
|
"""
|
|
32
30
|
Quantizer for applying the SmoothQuant algorithm
|
|
31
|
+
|
|
32
|
+
Q) Why allow choosing between input and output activations?
|
|
33
|
+
|
|
34
|
+
SmoothQuant relies on channel-wise activation statistics to balance
|
|
35
|
+
weights and activations. In practice, there are two natural sources:
|
|
36
|
+
|
|
37
|
+
- "input": captures the tensor right before a Linear layer
|
|
38
|
+
(forward-pre-hook). This matches the original SmoothQuant paper
|
|
39
|
+
and focuses on scaling the raw hidden state.
|
|
40
|
+
|
|
41
|
+
- "output": captures the tensor right after a Linear layer
|
|
42
|
+
(forward-hook). This can better reflect post-weight dynamics,
|
|
43
|
+
especially when subsequent operations (bias, activation functions)
|
|
44
|
+
dominate the dynamic range.
|
|
45
|
+
|
|
46
|
+
Allowing both options provides flexibility: depending on model
|
|
47
|
+
architecture and calibration data, one may yield lower error than
|
|
48
|
+
the other. The default remains "input" for compatibility, but "output"
|
|
49
|
+
can be selected to empirically reduce error or runtime overhead.
|
|
33
50
|
"""
|
|
34
51
|
|
|
35
52
|
def __init__(self, config: SmoothQuantConfig):
|
|
@@ -37,6 +54,7 @@ class SmoothQuantQuantizer(BaseQuantizer):
|
|
|
37
54
|
|
|
38
55
|
self.alpha = config.alpha
|
|
39
56
|
self.custom_alpha_map = config.custom_alpha_map
|
|
57
|
+
self.acts_from = config.acts_from # "input" (default) or "output"
|
|
40
58
|
self.observer: Optional[ChannelwiseMaxActsObserver] = None
|
|
41
59
|
|
|
42
60
|
@torch.no_grad()
|
|
@@ -55,7 +73,8 @@ class SmoothQuantQuantizer(BaseQuantizer):
|
|
|
55
73
|
Returns:
|
|
56
74
|
The model prepared for SmoothQuant quantization.
|
|
57
75
|
"""
|
|
58
|
-
|
|
76
|
+
# Attach hooks according to `config.acts_from`
|
|
77
|
+
self.observer = ChannelwiseMaxActsObserver(model, acts_from=self.acts_from)
|
|
59
78
|
self.observer.attach()
|
|
60
79
|
|
|
61
80
|
return model
|