tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251106__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. tico/__init__.py +1 -1
  2. tico/config/v1.py +5 -0
  3. tico/passes/cast_mixed_type_args.py +2 -0
  4. tico/passes/convert_expand_to_slice_cat.py +153 -0
  5. tico/passes/convert_matmul_to_linear.py +312 -0
  6. tico/passes/convert_to_relu6.py +1 -1
  7. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
  8. tico/passes/ops.py +0 -1
  9. tico/passes/remove_redundant_assert_nodes.py +3 -1
  10. tico/passes/remove_redundant_expand.py +3 -1
  11. tico/quantization/__init__.py +6 -0
  12. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
  13. tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
  14. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
  15. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  16. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  17. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  18. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  19. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
  20. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  21. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  22. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  23. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  24. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  25. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  26. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  27. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  28. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
  29. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  30. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  31. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  32. tico/quantization/config/base.py +26 -0
  33. tico/quantization/config/gptq.py +29 -0
  34. tico/quantization/config/pt2e.py +25 -0
  35. tico/quantization/config/ptq.py +119 -0
  36. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  37. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
  38. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  39. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  40. tico/quantization/evaluation/metric.py +146 -0
  41. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  42. tico/quantization/passes/__init__.py +1 -0
  43. tico/{experimental/quantization → quantization}/public_interface.py +11 -18
  44. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  45. tico/quantization/quantizer_registry.py +73 -0
  46. tico/quantization/wrapq/__init__.py +1 -0
  47. tico/quantization/wrapq/dtypes.py +70 -0
  48. tico/quantization/wrapq/examples/__init__.py +1 -0
  49. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  50. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  51. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  52. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  53. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  54. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  55. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  56. tico/quantization/wrapq/mode.py +32 -0
  57. tico/quantization/wrapq/observers/__init__.py +1 -0
  58. tico/quantization/wrapq/observers/affine_base.py +128 -0
  59. tico/quantization/wrapq/observers/base.py +98 -0
  60. tico/quantization/wrapq/observers/ema.py +62 -0
  61. tico/quantization/wrapq/observers/identity.py +74 -0
  62. tico/quantization/wrapq/observers/minmax.py +39 -0
  63. tico/quantization/wrapq/observers/mx.py +60 -0
  64. tico/quantization/wrapq/qscheme.py +40 -0
  65. tico/quantization/wrapq/quantizer.py +179 -0
  66. tico/quantization/wrapq/utils/__init__.py +1 -0
  67. tico/quantization/wrapq/utils/introspection.py +167 -0
  68. tico/quantization/wrapq/utils/metrics.py +124 -0
  69. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  70. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  71. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  72. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  73. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  74. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  75. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  76. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  77. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  78. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  79. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  80. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  81. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  82. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  83. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  84. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  85. tico/quantization/wrapq/wrappers/nn/quant_silu.py +60 -0
  86. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  87. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  88. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  89. tico/quantization/wrapq/wrappers/registry.py +128 -0
  90. tico/serialize/circle_serializer.py +11 -4
  91. tico/serialize/operators/adapters/__init__.py +1 -0
  92. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  93. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  94. tico/serialize/operators/op_le.py +54 -0
  95. tico/serialize/operators/op_mm.py +15 -132
  96. tico/serialize/operators/op_rmsnorm.py +65 -0
  97. tico/utils/convert.py +20 -15
  98. tico/utils/dtype.py +22 -0
  99. tico/utils/register_custom_op.py +29 -4
  100. tico/utils/signature.py +247 -0
  101. tico/utils/utils.py +50 -53
  102. tico/utils/validate_args_kwargs.py +37 -0
  103. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/METADATA +49 -2
  104. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/RECORD +130 -73
  105. tico/experimental/quantization/__init__.py +0 -6
  106. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  107. tico/experimental/quantization/evaluation/metric.py +0 -109
  108. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  109. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  110. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  111. /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
  112. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  113. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  114. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  115. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  116. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  117. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  118. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  119. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  120. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  121. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  122. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  123. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  124. /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
  125. /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
  126. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  127. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  128. /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
  129. /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
  130. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/LICENSE +0 -0
  131. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/WHEEL +0 -0
  132. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/entry_points.txt +0 -0
  133. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/top_level.txt +0 -0
@@ -17,15 +17,17 @@ import types
17
17
  from typing import Any, Callable, Dict, List, Optional
18
18
 
19
19
  import torch
20
+ from tqdm.auto import tqdm
20
21
 
21
- from tico.experimental.quantization.algorithm.gptq.gptq import GPTQ
22
- from tico.experimental.quantization.algorithm.gptq.utils import (
22
+ from tico.quantization.algorithm.gptq.gptq import GPTQ
23
+ from tico.quantization.algorithm.gptq.utils import (
23
24
  find_layers,
24
25
  gather_single_batch_from_dict,
25
26
  gather_single_batch_from_list,
26
27
  )
27
- from tico.experimental.quantization.config import BaseConfig, GPTQConfig
28
- from tico.experimental.quantization.quantizer import BaseQuantizer
28
+ from tico.quantization.config.gptq import GPTQConfig
29
+ from tico.quantization.quantizer import BaseQuantizer
30
+ from tico.quantization.quantizer_registry import register_quantizer
29
31
 
30
32
 
31
33
  class StopForward(Exception):
@@ -34,6 +36,7 @@ class StopForward(Exception):
34
36
  pass
35
37
 
36
38
 
39
+ @register_quantizer(GPTQConfig)
37
40
  class GPTQQuantizer(BaseQuantizer):
38
41
  """
39
42
  Quantizer for applying the GPTQ algorithm (typically for weight quantization).
@@ -43,7 +46,7 @@ class GPTQQuantizer(BaseQuantizer):
43
46
  3) convert(model) to consume the collected data and apply GPTQ.
44
47
  """
45
48
 
46
- def __init__(self, config: BaseConfig):
49
+ def __init__(self, config: GPTQConfig):
47
50
  super().__init__(config)
48
51
 
49
52
  # cache_args[i] -> list of the i-th positional argument for each batch
@@ -181,7 +184,14 @@ class GPTQQuantizer(BaseQuantizer):
181
184
  target_layers = [model]
182
185
 
183
186
  quantizers: Dict[str, Any] = {}
184
- for l_idx, layer in enumerate(target_layers):
187
+ for l_idx, layer in enumerate(
188
+ tqdm(
189
+ target_layers,
190
+ desc="Quantizing layers",
191
+ unit="layer",
192
+ disable=not gptq_conf.show_progress,
193
+ )
194
+ ):
185
195
  # 1) Identify quantizable submodules within the layer
186
196
  full = find_layers(layer)
187
197
  sequential = [list(full.keys())]
@@ -210,7 +220,13 @@ class GPTQQuantizer(BaseQuantizer):
210
220
 
211
221
  # Run layer forward over all cached batches to build Hessian/statistics
212
222
  batch_num = self.num_batches
213
- for batch_idx in range(batch_num):
223
+ for batch_idx in tqdm(
224
+ range(batch_num),
225
+ desc=f"[L{l_idx}] collecting",
226
+ leave=False,
227
+ unit="batch",
228
+ disable=not gptq_conf.show_progress,
229
+ ):
214
230
  cache_args_batch = gather_single_batch_from_list(
215
231
  self.cache_args, batch_idx
216
232
  )
@@ -238,7 +254,13 @@ class GPTQQuantizer(BaseQuantizer):
238
254
  gptq[name].free()
239
255
 
240
256
  # 4) After quantization, re-run the layer to produce outputs for the next layer
241
- for batch_idx in range(batch_num):
257
+ for batch_idx in tqdm(
258
+ range(batch_num),
259
+ desc=f"[L{l_idx}] re-forward",
260
+ leave=False,
261
+ unit="batch",
262
+ disable=not gptq_conf.show_progress,
263
+ ):
242
264
  cache_args_batch = gather_single_batch_from_list(
243
265
  self.cache_args, batch_idx
244
266
  )
@@ -25,14 +25,12 @@ from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObser
25
25
  from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
26
26
  from torch.ao.quantization.quantizer.utils import _get_module_name_filter
27
27
 
28
- from tico.experimental.quantization.algorithm.pt2e.annotation.op import *
29
- import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
30
- import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
31
- import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
32
- from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
33
- QuantizationConfig,
34
- )
35
- from tico.experimental.quantization.algorithm.pt2e.transformation.convert_scalars_to_attrs import (
28
+ from tico.quantization.algorithm.pt2e.annotation.op import *
29
+ import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
30
+ import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
31
+ import tico.quantization.algorithm.pt2e.utils as quant_utils
32
+ from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
33
+ from tico.quantization.algorithm.pt2e.transformation.convert_scalars_to_attrs import (
36
34
  convert_scalars_to_attrs,
37
35
  )
38
36
 
@@ -19,12 +19,10 @@ if TYPE_CHECKING:
19
19
  import torch
20
20
  from torch.ao.quantization.quantizer import SharedQuantizationSpec
21
21
 
22
- import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
23
- import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
24
- import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
25
- from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
26
- QuantizationConfig,
27
- )
22
+ import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
23
+ import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
24
+ import tico.quantization.algorithm.pt2e.utils as quant_utils
25
+ from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
28
26
  from tico.utils.validate_args_kwargs import AdaptiveAvgPool2dArgs
29
27
 
30
28
 
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
18
18
  import torch.fx
19
19
  import torch
20
20
 
21
- import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
22
- import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
23
- import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
24
- from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
25
- QuantizationConfig,
26
- )
21
+ import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
22
+ import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
23
+ import tico.quantization.algorithm.pt2e.utils as quant_utils
24
+ from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
27
25
  from tico.utils.validate_args_kwargs import AddTensorArgs
28
26
 
29
27
 
@@ -19,12 +19,10 @@ if TYPE_CHECKING:
19
19
  import torch
20
20
  from torch.ao.quantization.quantizer import DerivedQuantizationSpec
21
21
 
22
- import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
23
- import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
24
- import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
25
- from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
26
- QuantizationConfig,
27
- )
22
+ import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
23
+ import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
24
+ import tico.quantization.algorithm.pt2e.utils as quant_utils
25
+ from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
28
26
  from tico.utils.validate_args_kwargs import Conv2DArgs
29
27
 
30
28
 
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
18
18
  import torch.fx
19
19
  import torch
20
20
 
21
- import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
22
- import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
23
- import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
24
- from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
25
- QuantizationConfig,
26
- )
21
+ import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
22
+ import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
23
+ import tico.quantization.algorithm.pt2e.utils as quant_utils
24
+ from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
27
25
  from tico.utils.validate_args_kwargs import DivTensorArgs
28
26
 
29
27
 
@@ -19,12 +19,10 @@ if TYPE_CHECKING:
19
19
  import torch
20
20
  from torch.ao.quantization.quantizer import DerivedQuantizationSpec
21
21
 
22
- import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
23
- import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
24
- import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
25
- from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
26
- QuantizationConfig,
27
- )
22
+ import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
23
+ import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
24
+ import tico.quantization.algorithm.pt2e.utils as quant_utils
25
+ from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
28
26
  from tico.utils.validate_args_kwargs import LinearArgs
29
27
 
30
28
 
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
18
18
  import torch.fx
19
19
  import torch
20
20
 
21
- import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
22
- import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
23
- import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
24
- from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
25
- QuantizationConfig,
26
- )
21
+ import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
22
+ import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
23
+ import tico.quantization.algorithm.pt2e.utils as quant_utils
24
+ from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
27
25
  from tico.utils.validate_args_kwargs import MeanDimArgs
28
26
 
29
27
 
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
18
18
  import torch.fx
19
19
  import torch
20
20
 
21
- import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
22
- import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
23
- import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
24
- from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
25
- QuantizationConfig,
26
- )
21
+ import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
22
+ import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
23
+ import tico.quantization.algorithm.pt2e.utils as quant_utils
24
+ from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
27
25
  from tico.utils.validate_args_kwargs import MulTensorArgs
28
26
 
29
27
 
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
18
18
  import torch.fx
19
19
  import torch
20
20
 
21
- import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
22
- import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
23
- import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
24
- from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
25
- QuantizationConfig,
26
- )
21
+ import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
22
+ import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
23
+ import tico.quantization.algorithm.pt2e.utils as quant_utils
24
+ from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
27
25
  from tico.utils.validate_args_kwargs import Relu6Args
28
26
 
29
27
 
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
18
18
  import torch.fx
19
19
  import torch
20
20
 
21
- import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
22
- import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
23
- import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
24
- from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
25
- QuantizationConfig,
26
- )
21
+ import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
22
+ import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
23
+ import tico.quantization.algorithm.pt2e.utils as quant_utils
24
+ from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
27
25
  from tico.utils.validate_args_kwargs import RsqrtArgs
28
26
 
29
27
 
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
18
18
  import torch.fx
19
19
  import torch
20
20
 
21
- import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
22
- import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
23
- import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
24
- from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
25
- QuantizationConfig,
26
- )
21
+ import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
22
+ import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
23
+ import tico.quantization.algorithm.pt2e.utils as quant_utils
24
+ from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
27
25
  from tico.utils.validate_args_kwargs import SubTensorArgs
28
26
 
29
27
 
@@ -18,9 +18,7 @@ if TYPE_CHECKING:
18
18
  import torch.fx
19
19
  import torch
20
20
 
21
- from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
22
- QuantizationConfig,
23
- )
21
+ from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
24
22
 
25
23
  AnnotatorType = Callable[
26
24
  [
@@ -22,7 +22,7 @@ from torch.ao.quantization.quantizer import (
22
22
  SharedQuantizationSpec,
23
23
  )
24
24
 
25
- import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
25
+ import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
26
26
 
27
27
 
28
28
  def annotate_input_qspec_map(node: torch.fx.Node, input_node: torch.fx.Node, qspec):
@@ -18,13 +18,16 @@ import torch
18
18
 
19
19
  from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
20
20
 
21
- from tico.experimental.quantization.algorithm.pt2e.annotation.annotator import (
21
+ from tico.quantization.algorithm.pt2e.annotation.annotator import (
22
22
  get_asymmetric_quantization_config,
23
23
  PT2EAnnotator,
24
24
  )
25
- from tico.experimental.quantization.quantizer import BaseQuantizer
25
+ from tico.quantization.config.pt2e import PT2EConfig
26
+ from tico.quantization.quantizer import BaseQuantizer
27
+ from tico.quantization.quantizer_registry import register_quantizer
26
28
 
27
29
 
30
+ @register_quantizer(PT2EConfig)
28
31
  class PT2EQuantizer(BaseQuantizer):
29
32
  """
30
33
  Quantizer for applying pytorch 2.0 export quantization (typically for activation quantization).
@@ -20,9 +20,7 @@ import torch
20
20
  from torch.ao.quantization.quantizer import QuantizationSpec
21
21
  from torch.ao.quantization.quantizer.utils import _get_module_name_filter
22
22
 
23
- from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
24
- QuantizationConfig,
25
- )
23
+ from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
26
24
 
27
25
 
28
26
  def get_module_type_filter(tp: Callable):
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import functools
16
- from typing import Any, Dict, List
16
+ from typing import Any, Dict, List, Literal
17
17
 
18
18
  import torch
19
19
 
@@ -21,18 +21,24 @@ import torch
21
21
  class ChannelwiseMaxActsObserver:
22
22
  """
23
23
  Observer to calcuate channelwise maximum activation
24
+ It supports collecting activations from either module inputs or outputs.
24
25
  """
25
26
 
26
- def __init__(self, model):
27
+ def __init__(
28
+ self, model: torch.nn.Module, acts_from: Literal["input", "output"] = "input"
29
+ ):
27
30
  """
28
31
  model
29
32
  A torch module whose activations are to be analyzed.
33
+ acts_from
34
+ Where to hook: "input" for forward-pre-hook, "output" for forward-hook.
30
35
  hooks
31
- A list to store the hooks which are registered to collect activation statistics.
36
+ A list to store the hooks registered to collect activation statistics.
32
37
  max_acts
33
- A dictionary to store the maximum activation values
38
+ A dictionary to store the per-channel maxima.
34
39
  """
35
40
  self.model = model
41
+ self.acts_from: Literal["input", "output"] = acts_from
36
42
  self.hooks: List[Any] = []
37
43
  self.max_acts: Dict[str, torch.Tensor] = {}
38
44
 
@@ -62,13 +68,25 @@ class ChannelwiseMaxActsObserver:
62
68
  input = input[0]
63
69
  stat_tensor(name, input)
64
70
 
71
+ def stat_output_hook(m, input, output, name):
72
+ if isinstance(output, tuple):
73
+ output = output[0]
74
+ stat_tensor(name, output)
75
+
65
76
  for name, m in self.model.named_modules():
66
77
  if isinstance(m, torch.nn.Linear):
67
- self.hooks.append(
68
- m.register_forward_pre_hook(
69
- functools.partial(stat_input_hook, name=name)
78
+ if self.acts_from == "input":
79
+ self.hooks.append(
80
+ m.register_forward_pre_hook(
81
+ functools.partial(stat_input_hook, name=name)
82
+ )
83
+ )
84
+ else: # "output"
85
+ self.hooks.append(
86
+ m.register_forward_hook(
87
+ functools.partial(stat_output_hook, name=name)
88
+ )
70
89
  )
71
- )
72
90
 
73
91
  def remove(self):
74
92
  for hook in self.hooks:
@@ -16,20 +16,37 @@ from typing import Any, Dict, Optional
16
16
 
17
17
  import torch
18
18
 
19
- from tico.experimental.quantization.algorithm.smoothquant.observer import (
20
- ChannelwiseMaxActsObserver,
21
- )
19
+ from tico.quantization.algorithm.smoothquant.observer import ChannelwiseMaxActsObserver
22
20
 
23
- from tico.experimental.quantization.algorithm.smoothquant.smooth_quant import (
24
- apply_smoothing,
25
- )
26
- from tico.experimental.quantization.config import SmoothQuantConfig
27
- from tico.experimental.quantization.quantizer import BaseQuantizer
21
+ from tico.quantization.algorithm.smoothquant.smooth_quant import apply_smoothing
22
+ from tico.quantization.config.smoothquant import SmoothQuantConfig
23
+ from tico.quantization.quantizer import BaseQuantizer
24
+ from tico.quantization.quantizer_registry import register_quantizer
28
25
 
29
26
 
27
+ @register_quantizer(SmoothQuantConfig)
30
28
  class SmoothQuantQuantizer(BaseQuantizer):
31
29
  """
32
30
  Quantizer for applying the SmoothQuant algorithm
31
+
32
+ Q) Why allow choosing between input and output activations?
33
+
34
+ SmoothQuant relies on channel-wise activation statistics to balance
35
+ weights and activations. In practice, there are two natural sources:
36
+
37
+ - "input": captures the tensor right before a Linear layer
38
+ (forward-pre-hook). This matches the original SmoothQuant paper
39
+ and focuses on scaling the raw hidden state.
40
+
41
+ - "output": captures the tensor right after a Linear layer
42
+ (forward-hook). This can better reflect post-weight dynamics,
43
+ especially when subsequent operations (bias, activation functions)
44
+ dominate the dynamic range.
45
+
46
+ Allowing both options provides flexibility: depending on model
47
+ architecture and calibration data, one may yield lower error than
48
+ the other. The default remains "input" for compatibility, but "output"
49
+ can be selected to empirically reduce error or runtime overhead.
33
50
  """
34
51
 
35
52
  def __init__(self, config: SmoothQuantConfig):
@@ -37,6 +54,7 @@ class SmoothQuantQuantizer(BaseQuantizer):
37
54
 
38
55
  self.alpha = config.alpha
39
56
  self.custom_alpha_map = config.custom_alpha_map
57
+ self.acts_from = config.acts_from # "input" (default) or "output"
40
58
  self.observer: Optional[ChannelwiseMaxActsObserver] = None
41
59
 
42
60
  @torch.no_grad()
@@ -55,7 +73,8 @@ class SmoothQuantQuantizer(BaseQuantizer):
55
73
  Returns:
56
74
  The model prepared for SmoothQuant quantization.
57
75
  """
58
- self.observer = ChannelwiseMaxActsObserver(model)
76
+ # Attach hooks according to `config.acts_from`
77
+ self.observer = ChannelwiseMaxActsObserver(model, acts_from=self.acts_from)
59
78
  self.observer.attach()
60
79
 
61
80
  return model