tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +125 -0
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/utils/convert.py +20 -15
- tico/utils/dtype.py +22 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/signature.py +247 -0
- tico/utils/utils.py +50 -53
- tico/utils/validate_args_kwargs.py +37 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/RECORD +130 -73
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py
RENAMED
|
@@ -19,12 +19,10 @@ if TYPE_CHECKING:
|
|
|
19
19
|
import torch
|
|
20
20
|
from torch.ao.quantization.quantizer import SharedQuantizationSpec
|
|
21
21
|
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
import tico.
|
|
25
|
-
from tico.
|
|
26
|
-
QuantizationConfig,
|
|
27
|
-
)
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
23
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
24
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
25
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
28
26
|
from tico.utils.validate_args_kwargs import AdaptiveAvgPool2dArgs
|
|
29
27
|
|
|
30
28
|
|
|
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
import tico.
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
from tico.
|
|
25
|
-
QuantizationConfig,
|
|
26
|
-
)
|
|
21
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
23
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
24
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
27
25
|
from tico.utils.validate_args_kwargs import AddTensorArgs
|
|
28
26
|
|
|
29
27
|
|
|
@@ -19,12 +19,10 @@ if TYPE_CHECKING:
|
|
|
19
19
|
import torch
|
|
20
20
|
from torch.ao.quantization.quantizer import DerivedQuantizationSpec
|
|
21
21
|
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
import tico.
|
|
25
|
-
from tico.
|
|
26
|
-
QuantizationConfig,
|
|
27
|
-
)
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
23
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
24
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
25
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
28
26
|
from tico.utils.validate_args_kwargs import Conv2DArgs
|
|
29
27
|
|
|
30
28
|
|
|
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
import tico.
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
from tico.
|
|
25
|
-
QuantizationConfig,
|
|
26
|
-
)
|
|
21
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
23
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
24
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
27
25
|
from tico.utils.validate_args_kwargs import DivTensorArgs
|
|
28
26
|
|
|
29
27
|
|
|
@@ -19,12 +19,10 @@ if TYPE_CHECKING:
|
|
|
19
19
|
import torch
|
|
20
20
|
from torch.ao.quantization.quantizer import DerivedQuantizationSpec
|
|
21
21
|
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
import tico.
|
|
25
|
-
from tico.
|
|
26
|
-
QuantizationConfig,
|
|
27
|
-
)
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
23
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
24
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
25
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
28
26
|
from tico.utils.validate_args_kwargs import LinearArgs
|
|
29
27
|
|
|
30
28
|
|
|
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
import tico.
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
from tico.
|
|
25
|
-
QuantizationConfig,
|
|
26
|
-
)
|
|
21
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
23
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
24
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
27
25
|
from tico.utils.validate_args_kwargs import MeanDimArgs
|
|
28
26
|
|
|
29
27
|
|
|
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
import tico.
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
from tico.
|
|
25
|
-
QuantizationConfig,
|
|
26
|
-
)
|
|
21
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
23
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
24
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
27
25
|
from tico.utils.validate_args_kwargs import MulTensorArgs
|
|
28
26
|
|
|
29
27
|
|
|
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
import tico.
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
from tico.
|
|
25
|
-
QuantizationConfig,
|
|
26
|
-
)
|
|
21
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
23
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
24
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
27
25
|
from tico.utils.validate_args_kwargs import Relu6Args
|
|
28
26
|
|
|
29
27
|
|
|
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
import tico.
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
from tico.
|
|
25
|
-
QuantizationConfig,
|
|
26
|
-
)
|
|
21
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
23
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
24
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
27
25
|
from tico.utils.validate_args_kwargs import RsqrtArgs
|
|
28
26
|
|
|
29
27
|
|
|
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
import tico.
|
|
22
|
-
import tico.
|
|
23
|
-
import tico.
|
|
24
|
-
from tico.
|
|
25
|
-
QuantizationConfig,
|
|
26
|
-
)
|
|
21
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
22
|
+
import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
|
|
23
|
+
import tico.quantization.algorithm.pt2e.utils as quant_utils
|
|
24
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
27
25
|
from tico.utils.validate_args_kwargs import SubTensorArgs
|
|
28
26
|
|
|
29
27
|
|
|
@@ -18,9 +18,7 @@ if TYPE_CHECKING:
|
|
|
18
18
|
import torch.fx
|
|
19
19
|
import torch
|
|
20
20
|
|
|
21
|
-
from tico.
|
|
22
|
-
QuantizationConfig,
|
|
23
|
-
)
|
|
21
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
24
22
|
|
|
25
23
|
AnnotatorType = Callable[
|
|
26
24
|
[
|
|
@@ -22,7 +22,7 @@ from torch.ao.quantization.quantizer import (
|
|
|
22
22
|
SharedQuantizationSpec,
|
|
23
23
|
)
|
|
24
24
|
|
|
25
|
-
import tico.
|
|
25
|
+
import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
def annotate_input_qspec_map(node: torch.fx.Node, input_node: torch.fx.Node, qspec):
|
|
@@ -18,13 +18,16 @@ import torch
|
|
|
18
18
|
|
|
19
19
|
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
|
|
20
20
|
|
|
21
|
-
from tico.
|
|
21
|
+
from tico.quantization.algorithm.pt2e.annotation.annotator import (
|
|
22
22
|
get_asymmetric_quantization_config,
|
|
23
23
|
PT2EAnnotator,
|
|
24
24
|
)
|
|
25
|
-
from tico.
|
|
25
|
+
from tico.quantization.config.pt2e import PT2EConfig
|
|
26
|
+
from tico.quantization.quantizer import BaseQuantizer
|
|
27
|
+
from tico.quantization.quantizer_registry import register_quantizer
|
|
26
28
|
|
|
27
29
|
|
|
30
|
+
@register_quantizer(PT2EConfig)
|
|
28
31
|
class PT2EQuantizer(BaseQuantizer):
|
|
29
32
|
"""
|
|
30
33
|
Quantizer for applying pytorch 2.0 export quantization (typically for activation quantization).
|
|
@@ -20,9 +20,7 @@ import torch
|
|
|
20
20
|
from torch.ao.quantization.quantizer import QuantizationSpec
|
|
21
21
|
from torch.ao.quantization.quantizer.utils import _get_module_name_filter
|
|
22
22
|
|
|
23
|
-
from tico.
|
|
24
|
-
QuantizationConfig,
|
|
25
|
-
)
|
|
23
|
+
from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
|
|
26
24
|
|
|
27
25
|
|
|
28
26
|
def get_module_type_filter(tp: Callable):
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
import functools
|
|
16
|
-
from typing import Any, Dict, List
|
|
16
|
+
from typing import Any, Dict, List, Literal
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
19
|
|
|
@@ -21,18 +21,24 @@ import torch
|
|
|
21
21
|
class ChannelwiseMaxActsObserver:
|
|
22
22
|
"""
|
|
23
23
|
Observer to calcuate channelwise maximum activation
|
|
24
|
+
It supports collecting activations from either module inputs or outputs.
|
|
24
25
|
"""
|
|
25
26
|
|
|
26
|
-
def __init__(
|
|
27
|
+
def __init__(
|
|
28
|
+
self, model: torch.nn.Module, acts_from: Literal["input", "output"] = "input"
|
|
29
|
+
):
|
|
27
30
|
"""
|
|
28
31
|
model
|
|
29
32
|
A torch module whose activations are to be analyzed.
|
|
33
|
+
acts_from
|
|
34
|
+
Where to hook: "input" for forward-pre-hook, "output" for forward-hook.
|
|
30
35
|
hooks
|
|
31
|
-
A list to store the hooks
|
|
36
|
+
A list to store the hooks registered to collect activation statistics.
|
|
32
37
|
max_acts
|
|
33
|
-
A dictionary to store the
|
|
38
|
+
A dictionary to store the per-channel maxima.
|
|
34
39
|
"""
|
|
35
40
|
self.model = model
|
|
41
|
+
self.acts_from: Literal["input", "output"] = acts_from
|
|
36
42
|
self.hooks: List[Any] = []
|
|
37
43
|
self.max_acts: Dict[str, torch.Tensor] = {}
|
|
38
44
|
|
|
@@ -62,13 +68,25 @@ class ChannelwiseMaxActsObserver:
|
|
|
62
68
|
input = input[0]
|
|
63
69
|
stat_tensor(name, input)
|
|
64
70
|
|
|
71
|
+
def stat_output_hook(m, input, output, name):
|
|
72
|
+
if isinstance(output, tuple):
|
|
73
|
+
output = output[0]
|
|
74
|
+
stat_tensor(name, output)
|
|
75
|
+
|
|
65
76
|
for name, m in self.model.named_modules():
|
|
66
77
|
if isinstance(m, torch.nn.Linear):
|
|
67
|
-
self.
|
|
68
|
-
|
|
69
|
-
|
|
78
|
+
if self.acts_from == "input":
|
|
79
|
+
self.hooks.append(
|
|
80
|
+
m.register_forward_pre_hook(
|
|
81
|
+
functools.partial(stat_input_hook, name=name)
|
|
82
|
+
)
|
|
83
|
+
)
|
|
84
|
+
else: # "output"
|
|
85
|
+
self.hooks.append(
|
|
86
|
+
m.register_forward_hook(
|
|
87
|
+
functools.partial(stat_output_hook, name=name)
|
|
88
|
+
)
|
|
70
89
|
)
|
|
71
|
-
)
|
|
72
90
|
|
|
73
91
|
def remove(self):
|
|
74
92
|
for hook in self.hooks:
|
|
@@ -16,20 +16,37 @@ from typing import Any, Dict, Optional
|
|
|
16
16
|
|
|
17
17
|
import torch
|
|
18
18
|
|
|
19
|
-
from tico.
|
|
20
|
-
ChannelwiseMaxActsObserver,
|
|
21
|
-
)
|
|
19
|
+
from tico.quantization.algorithm.smoothquant.observer import ChannelwiseMaxActsObserver
|
|
22
20
|
|
|
23
|
-
from tico.
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
from tico.
|
|
27
|
-
from tico.experimental.quantization.quantizer import BaseQuantizer
|
|
21
|
+
from tico.quantization.algorithm.smoothquant.smooth_quant import apply_smoothing
|
|
22
|
+
from tico.quantization.config.smoothquant import SmoothQuantConfig
|
|
23
|
+
from tico.quantization.quantizer import BaseQuantizer
|
|
24
|
+
from tico.quantization.quantizer_registry import register_quantizer
|
|
28
25
|
|
|
29
26
|
|
|
27
|
+
@register_quantizer(SmoothQuantConfig)
|
|
30
28
|
class SmoothQuantQuantizer(BaseQuantizer):
|
|
31
29
|
"""
|
|
32
30
|
Quantizer for applying the SmoothQuant algorithm
|
|
31
|
+
|
|
32
|
+
Q) Why allow choosing between input and output activations?
|
|
33
|
+
|
|
34
|
+
SmoothQuant relies on channel-wise activation statistics to balance
|
|
35
|
+
weights and activations. In practice, there are two natural sources:
|
|
36
|
+
|
|
37
|
+
- "input": captures the tensor right before a Linear layer
|
|
38
|
+
(forward-pre-hook). This matches the original SmoothQuant paper
|
|
39
|
+
and focuses on scaling the raw hidden state.
|
|
40
|
+
|
|
41
|
+
- "output": captures the tensor right after a Linear layer
|
|
42
|
+
(forward-hook). This can better reflect post-weight dynamics,
|
|
43
|
+
especially when subsequent operations (bias, activation functions)
|
|
44
|
+
dominate the dynamic range.
|
|
45
|
+
|
|
46
|
+
Allowing both options provides flexibility: depending on model
|
|
47
|
+
architecture and calibration data, one may yield lower error than
|
|
48
|
+
the other. The default remains "input" for compatibility, but "output"
|
|
49
|
+
can be selected to empirically reduce error or runtime overhead.
|
|
33
50
|
"""
|
|
34
51
|
|
|
35
52
|
def __init__(self, config: SmoothQuantConfig):
|
|
@@ -37,6 +54,7 @@ class SmoothQuantQuantizer(BaseQuantizer):
|
|
|
37
54
|
|
|
38
55
|
self.alpha = config.alpha
|
|
39
56
|
self.custom_alpha_map = config.custom_alpha_map
|
|
57
|
+
self.acts_from = config.acts_from # "input" (default) or "output"
|
|
40
58
|
self.observer: Optional[ChannelwiseMaxActsObserver] = None
|
|
41
59
|
|
|
42
60
|
@torch.no_grad()
|
|
@@ -55,7 +73,8 @@ class SmoothQuantQuantizer(BaseQuantizer):
|
|
|
55
73
|
Returns:
|
|
56
74
|
The model prepared for SmoothQuant quantization.
|
|
57
75
|
"""
|
|
58
|
-
|
|
76
|
+
# Attach hooks according to `config.acts_from`
|
|
77
|
+
self.observer = ChannelwiseMaxActsObserver(model, acts_from=self.acts_from)
|
|
59
78
|
self.observer.attach()
|
|
60
79
|
|
|
61
80
|
return model
|