tico 0.1.0.dev250904__py3-none-any.whl → 0.1.0.dev251109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (133)
  1. tico/__init__.py +1 -1
  2. tico/config/v1.py +5 -0
  3. tico/passes/cast_mixed_type_args.py +2 -0
  4. tico/passes/convert_expand_to_slice_cat.py +153 -0
  5. tico/passes/convert_matmul_to_linear.py +312 -0
  6. tico/passes/convert_to_relu6.py +1 -1
  7. tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -3
  8. tico/passes/ops.py +0 -1
  9. tico/passes/remove_redundant_expand.py +3 -1
  10. tico/quantization/__init__.py +6 -0
  11. tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +161 -0
  12. tico/quantization/algorithm/fpi_gptq/quantizer.py +179 -0
  13. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
  14. tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +14 -6
  15. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
  16. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  17. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  18. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  19. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  20. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
  21. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  22. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  23. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  24. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  25. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  26. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  27. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  28. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  29. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
  30. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  31. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  32. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  33. tico/quantization/config/base.py +26 -0
  34. tico/quantization/config/fpi_gptq.py +29 -0
  35. tico/quantization/config/gptq.py +29 -0
  36. tico/quantization/config/pt2e.py +25 -0
  37. tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py} +18 -10
  38. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -37
  39. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +6 -12
  40. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  41. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  42. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/public_interface.py +11 -18
  44. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  45. tico/quantization/quantizer_registry.py +73 -0
  46. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  47. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  48. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py +11 -10
  49. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py +10 -12
  50. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py +10 -9
  51. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py +13 -13
  52. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  53. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py +3 -3
  54. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py +2 -2
  55. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py +2 -2
  56. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py +1 -1
  57. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py +2 -2
  58. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py +1 -1
  59. tico/quantization/wrapq/quantizer.py +179 -0
  60. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py +3 -5
  61. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py +3 -2
  62. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  63. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  64. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  65. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  66. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  67. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  68. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  69. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  70. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py +58 -21
  71. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py +21 -13
  72. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py +5 -7
  73. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  74. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py +6 -7
  75. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py +7 -8
  76. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py +8 -9
  77. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py +4 -6
  78. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py +55 -17
  79. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py +10 -9
  80. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py +17 -10
  81. tico/serialize/circle_serializer.py +11 -4
  82. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  83. tico/serialize/operators/op_le.py +54 -0
  84. tico/serialize/operators/op_mm.py +15 -132
  85. tico/utils/convert.py +20 -15
  86. tico/utils/register_custom_op.py +6 -4
  87. tico/utils/signature.py +7 -8
  88. tico/utils/validate_args_kwargs.py +12 -0
  89. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/METADATA +48 -2
  90. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/RECORD +128 -108
  91. tico/experimental/quantization/__init__.py +0 -6
  92. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  93. tico/experimental/quantization/ptq/examples/compare_ppl.py +0 -121
  94. tico/experimental/quantization/ptq/examples/debug_quant_outputs.py +0 -129
  95. tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +0 -165
  96. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  97. /tico/{experimental/quantization/algorithm/gptq → quantization/algorithm/fpi_gptq}/__init__.py +0 -0
  98. /tico/{experimental/quantization/algorithm/pt2e → quantization/algorithm/gptq}/__init__.py +0 -0
  99. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  100. /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
  101. /tico/{experimental/quantization/algorithm/pt2e/annotation → quantization/algorithm/pt2e}/__init__.py +0 -0
  102. /tico/{experimental/quantization/algorithm/pt2e/transformation → quantization/algorithm/pt2e/annotation}/__init__.py +0 -0
  103. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  104. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  105. /tico/{experimental/quantization/algorithm/smoothquant → quantization/algorithm/pt2e/transformation}/__init__.py +0 -0
  106. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  107. /tico/{experimental/quantization/evaluation → quantization/algorithm/smoothquant}/__init__.py +0 -0
  108. /tico/{experimental/quantization/evaluation/executor → quantization/config}/__init__.py +0 -0
  109. /tico/{experimental/quantization/passes → quantization/evaluation}/__init__.py +0 -0
  110. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  111. /tico/{experimental/quantization/ptq → quantization/evaluation/executor}/__init__.py +0 -0
  112. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  113. /tico/{experimental/quantization → quantization}/evaluation/metric.py +0 -0
  114. /tico/{experimental/quantization/ptq/examples → quantization/passes}/__init__.py +0 -0
  115. /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
  116. /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
  117. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  118. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  119. /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
  120. /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
  121. /tico/{experimental/quantization/ptq/observers → quantization/wrapq}/__init__.py +0 -0
  122. /tico/{experimental/quantization/ptq → quantization/wrapq}/dtypes.py +0 -0
  123. /tico/{experimental/quantization/ptq/utils → quantization/wrapq/examples}/__init__.py +0 -0
  124. /tico/{experimental/quantization/ptq → quantization/wrapq}/mode.py +0 -0
  125. /tico/{experimental/quantization/ptq/wrappers → quantization/wrapq/observers}/__init__.py +0 -0
  126. /tico/{experimental/quantization/ptq → quantization/wrapq}/qscheme.py +0 -0
  127. /tico/{experimental/quantization/ptq/wrappers/llama → quantization/wrapq/utils}/__init__.py +0 -0
  128. /tico/{experimental/quantization/ptq → quantization/wrapq}/utils/reduce_utils.py +0 -0
  129. /tico/{experimental/quantization/ptq/wrappers/nn → quantization/wrapq/wrappers}/__init__.py +0 -0
  130. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/LICENSE +0 -0
  131. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/WHEEL +0 -0
  132. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/entry_points.txt +0 -0
  133. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/top_level.txt +0 -0

tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py
@@ -18,13 +18,16 @@ import torch
 
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
-from tico.experimental.quantization.algorithm.pt2e.annotation.annotator import (
+from tico.quantization.algorithm.pt2e.annotation.annotator import (
     get_asymmetric_quantization_config,
     PT2EAnnotator,
 )
-from tico.experimental.quantization.quantizer import BaseQuantizer
+from tico.quantization.config.pt2e import PT2EConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import register_quantizer
 
 
+@register_quantizer(PT2EConfig)
 class PT2EQuantizer(BaseQuantizer):
     """
     Quantizer for applying pytorch 2.0 export quantization (typically for activation quantization).
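
The hunk above ties `PT2EQuantizer` to `PT2EConfig` via the new `register_quantizer` decorator. The registry module itself (`tico/quantization/quantizer_registry.py`, +73 lines) is not shown in this diff, so the sketch below is only a plausible reading of how such a config-keyed registry is usually wired; `get_quantizer` and the internal dict are assumptions, not the released API.

```python
from typing import Callable, Dict

# Hypothetical internals: map a config class to the quantizer registered for it.
_REGISTRY: Dict[type, type] = {}


def register_quantizer(config_cls: type) -> Callable[[type], type]:
    """Class decorator: associate a quantizer class with its config type."""

    def wrap(quantizer_cls: type) -> type:
        _REGISTRY[config_cls] = quantizer_cls
        return quantizer_cls

    return wrap


def get_quantizer(config: object) -> object:
    """Instantiate the quantizer registered for type(config) (assumed helper)."""
    return _REGISTRY[type(config)](config)
```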

tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py
@@ -20,9 +20,7 @@ import torch
 from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.utils import _get_module_name_filter
 
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 
 
 def get_module_type_filter(tp: Callable):

tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import functools
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Literal
 
 import torch
 
@@ -21,18 +21,24 @@ import torch
 class ChannelwiseMaxActsObserver:
     """
     Observer to calcuate channelwise maximum activation
+    It supports collecting activations from either module inputs or outputs.
     """
 
-    def __init__(self, model):
+    def __init__(
+        self, model: torch.nn.Module, acts_from: Literal["input", "output"] = "input"
+    ):
         """
         model
             A torch module whose activations are to be analyzed.
+        acts_from
+            Where to hook: "input" for forward-pre-hook, "output" for forward-hook.
        hooks
-            A list to store the hooks which are registered to collect activation statistics.
+            A list to store the hooks registered to collect activation statistics.
         max_acts
-            A dictionary to store the maximum activation values
+            A dictionary to store the per-channel maxima.
         """
         self.model = model
+        self.acts_from: Literal["input", "output"] = acts_from
         self.hooks: List[Any] = []
         self.max_acts: Dict[str, torch.Tensor] = {}
 
@@ -62,13 +68,25 @@ class ChannelwiseMaxActsObserver:
                 input = input[0]
             stat_tensor(name, input)
 
+        def stat_output_hook(m, input, output, name):
+            if isinstance(output, tuple):
+                output = output[0]
+            stat_tensor(name, output)
+
         for name, m in self.model.named_modules():
             if isinstance(m, torch.nn.Linear):
-                self.hooks.append(
-                    m.register_forward_pre_hook(
-                        functools.partial(stat_input_hook, name=name)
+                if self.acts_from == "input":
+                    self.hooks.append(
+                        m.register_forward_pre_hook(
+                            functools.partial(stat_input_hook, name=name)
+                        )
+                    )
+                else:  # "output"
+                    self.hooks.append(
+                        m.register_forward_hook(
+                            functools.partial(stat_output_hook, name=name)
+                        )
                     )
-                )
 
     def remove(self):
         for hook in self.hooks:
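
Usage sketch for the updated observer: collecting per-channel maxima from Linear outputs instead of inputs. `attach()`, `remove()` and `max_acts` are the API visible in this diff; the toy model and calibration data below are illustrative.

```python
import torch

from tico.quantization.algorithm.smoothquant.observer import ChannelwiseMaxActsObserver

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8)
)

observer = ChannelwiseMaxActsObserver(model, acts_from="output")
observer.attach()  # registers forward-hooks on every nn.Linear

with torch.no_grad():  # calibration forward passes
    for _ in range(4):
        model(torch.randn(2, 16))

observer.remove()  # detach the hooks
print({name: t.shape for name, t in observer.max_acts.items()})
```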

tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py
@@ -16,20 +16,37 @@ from typing import Any, Dict, Optional
 
 import torch
 
-from tico.experimental.quantization.algorithm.smoothquant.observer import (
-    ChannelwiseMaxActsObserver,
-)
+from tico.quantization.algorithm.smoothquant.observer import ChannelwiseMaxActsObserver
 
-from tico.experimental.quantization.algorithm.smoothquant.smooth_quant import (
-    apply_smoothing,
-)
-from tico.experimental.quantization.config import SmoothQuantConfig
-from tico.experimental.quantization.quantizer import BaseQuantizer
+from tico.quantization.algorithm.smoothquant.smooth_quant import apply_smoothing
+from tico.quantization.config.smoothquant import SmoothQuantConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import register_quantizer
 
 
+@register_quantizer(SmoothQuantConfig)
 class SmoothQuantQuantizer(BaseQuantizer):
     """
     Quantizer for applying the SmoothQuant algorithm
+
+    Q) Why allow choosing between input and output activations?
+
+    SmoothQuant relies on channel-wise activation statistics to balance
+    weights and activations. In practice, there are two natural sources:
+
+    - "input": captures the tensor right before a Linear layer
+      (forward-pre-hook). This matches the original SmoothQuant paper
+      and focuses on scaling the raw hidden state.
+
+    - "output": captures the tensor right after a Linear layer
+      (forward-hook). This can better reflect post-weight dynamics,
+      especially when subsequent operations (bias, activation functions)
+      dominate the dynamic range.
+
+    Allowing both options provides flexibility: depending on model
+    architecture and calibration data, one may yield lower error than
+    the other. The default remains "input" for compatibility, but "output"
+    can be selected to empirically reduce error or runtime overhead.
     """
 
     def __init__(self, config: SmoothQuantConfig):
@@ -37,6 +54,7 @@ class SmoothQuantQuantizer(BaseQuantizer):
 
         self.alpha = config.alpha
         self.custom_alpha_map = config.custom_alpha_map
+        self.acts_from = config.acts_from  # "input" (default) or "output"
         self.observer: Optional[ChannelwiseMaxActsObserver] = None
 
     @torch.no_grad()
@@ -55,7 +73,8 @@ class SmoothQuantQuantizer(BaseQuantizer):
         Returns:
             The model prepared for SmoothQuant quantization.
         """
-        self.observer = ChannelwiseMaxActsObserver(model)
+        # Attach hooks according to `config.acts_from`
+        self.observer = ChannelwiseMaxActsObserver(model, acts_from=self.acts_from)
         self.observer.attach()
 
         return model
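
End-to-end flow implied by the quantizer changes. Only the config fields (`alpha`, `custom_alpha_map`, `acts_from`) and the two classes are confirmed by this diff; the `prepare`/`convert` method names and the config constructor signature follow the usual BaseQuantizer pattern and are assumptions here.

```python
import torch

from tico.quantization.algorithm.smoothquant.quantizer import SmoothQuantQuantizer
from tico.quantization.config.smoothquant import SmoothQuantConfig

model = ...          # e.g. a Hugging Face LlamaForCausalLM
calib_batches = ...  # small calibration set

cfg = SmoothQuantConfig(alpha=0.5, acts_from="output")  # kwargs assumed from the fields read in __init__
quantizer = SmoothQuantQuantizer(cfg)

model = quantizer.prepare(model)   # attaches ChannelwiseMaxActsObserver hooks (method name assumed)
with torch.no_grad():
    for batch in calib_batches:    # forward passes populate the channel-wise maxima
        model(batch)
model = quantizer.convert(model)   # folds the smoothing scales into the weights (method name assumed)
```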

tico/quantization/algorithm/smoothquant/smooth_quant.py
@@ -0,0 +1,327 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Dict, List, Optional
+
+import torch
+
+
+@torch.no_grad()
+def smooth_weights(
+    front_module: torch.nn.Module,
+    back_modules: torch.nn.Module | List[torch.nn.Module],
+    activation_max: torch.Tensor,
+    alpha: float,
+):
+    """
+    Applies SmoothQuant-style smoothing to the weights and biases of two
+    connected modules using activation maximum values.
+
+    NOTE All modules **MUST** have `weight` and optionally `bias` attributes.
+
+    Parameters
+    -----------
+    front_module
+        The front module whose weights and biases will be adjusted.
+    back_modules
+        A list of back modules whose weights and biases will be adjusted.
+    activation_max
+        A tensor of channel-wise maximum activation values for the front module.
+    alpha
+        The smoothing factor that determines the scaling for weight adjustments.
+
+    Raises
+    -------
+    AttributeError
+        If `front_module` or any module in `back_modules` does not have `weight` attributes.
+    ValueError
+        If the shape of tensors in `activation_max` does not match the number of channels
+        in `front_module`'s weight.
+    NoteImplementedError
+        If `front_module` or any module in `back_modules` is of an unsupported type.
+    """
+    from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
+    if not isinstance(back_modules, list):
+        back_modules = [back_modules]
+
+    # Check attributes
+    if not hasattr(front_module, "weight"):
+        raise AttributeError(
+            f"The front module '{type(front_module).__name__}' does not have a 'weight' attribute."
+        )
+    for back_m in back_modules:
+        if not hasattr(back_m, "weight"):
+            raise AttributeError(
+                f"The front module '{type(back_m).__name__}' does not have a 'weight' attribute."
+            )
+    # Check shapes
+    if isinstance(front_module, LlamaRMSNorm):
+        front_numel = front_module.weight.numel()
+    else:
+        raise NotImplementedError(
+            f"Unsupported module type: {type(front_module).__name__}"
+        )
+    for back_m in back_modules:
+        if isinstance(back_m, torch.nn.Linear):
+            back_numel = back_m.in_features
+        else:
+            raise NotImplementedError(
+                f"Unsupported module type: {type(front_module).__name__}"
+            )
+
+    if front_numel != back_numel or back_numel != activation_max.numel():
+        raise ValueError(
+            f"Shape mismatch: front_numel({front_numel}), back_numel({back_numel}), activation_max_numel({activation_max.numel()})"
+        )
+
+    # Compute scales
+    device, dtype = back_modules[0].weight.device, back_modules[0].weight.dtype
+    activation_max = activation_max.to(device=device, dtype=dtype)  # type: ignore[arg-type]
+    weight_scales = torch.cat(
+        [back_m.weight.abs().max(dim=0, keepdim=True)[0] for back_m in back_modules],  # type: ignore[operator]
+        dim=0,
+    )
+    weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
+    scales = (
+        (activation_max.pow(alpha) / weight_scales.pow(1 - alpha))
+        .clamp(min=1e-5)
+        .to(device)  # type: ignore[arg-type]
+        .to(dtype)  # type: ignore[arg-type]
+    )
+
+    # Smooth
+    front_module.weight.div_(scales)
+    if hasattr(front_module, "bias"):
+        front_module.bias.div_(scales)
+
+    for back_m in back_modules:
+        back_m.weight.mul_(scales.view(1, -1))  # type: ignore[operator]
+
+
+# TODO Split the files per model
+# ────────────────────────────────────────────────────────────
+# fairseq ReLU bridge (input-hook stats) helpers
+# ────────────────────────────────────────────────────────────
+
+
+@torch.no_grad()
+def _compute_s_for_linear(
+    linear_like: torch.nn.Module,  # 2D weight [out, in]
+    activation_max: torch.Tensor,  # shape [in]
+    alpha: float,
+) -> torch.Tensor:
+    """
+    s = (amax^alpha / w_col_max^(1-alpha))
+    - amax: channel-wise max of the input to this module
+    - w_col_max: max(|W|) per input column
+    """
+    if not hasattr(linear_like, "weight"):
+        raise RuntimeError(f"{type(linear_like).__name__} has no 'weight' attribute.")
+    W = linear_like.weight  # [out, in]
+    assert isinstance(W, torch.Tensor)
+    if W.ndim != 2:
+        raise RuntimeError(
+            f"Expected 2D weight, got {W.ndim}D for {type(linear_like).__name__}"
+        )
+
+    device, dtype = W.device, W.dtype
+    amax = activation_max.to(device=device, dtype=dtype)
+
+    if amax.numel() != W.shape[1]:
+        raise ValueError(
+            f"activation_max numel({amax.numel()}) != in_features({W.shape[1]})"
+        )
+
+    w_col_max = W.abs().max(dim=0)[0].clamp(min=1e-5)  # [in]
+    s = (amax.pow(alpha) / w_col_max.pow(1.0 - alpha)).clamp(min=1e-5)  # [in]
+    return s
+
+
+@torch.no_grad()
+def _fuse_relu_bridge_no_runtime_mul(
+    fc1: torch.nn.Module,
+    fc2: torch.nn.Module,
+    s_hidden: torch.Tensor,
+):
+    """
+    Fuse scaling across fc1 → ReLU → fc2 without runtime multiplies:
+      - fc1 rows *= 1/s, (fc1.bias *= 1/s)
+      - fc2 cols *= s
+    Assumes middle activation is ReLU (positive homogeneous).
+    """
+    if not hasattr(fc1, "weight") or not hasattr(fc2, "weight"):
+        raise RuntimeError("fc1/fc2 must have 'weight' attributes.")
+
+    W1, W2 = fc1.weight, fc2.weight
+    assert isinstance(W1, torch.Tensor) and isinstance(W2, torch.Tensor)
+    if W1.ndim != 2 or W2.ndim != 2:
+        raise RuntimeError("fc1/fc2 weights must be 2D.")
+
+    hidden = W1.shape[0]
+    if W2.shape[1] != hidden or s_hidden.numel() != hidden:
+        raise ValueError(
+            f"Dimension mismatch: hidden={hidden}, W2.in={W2.shape[1]}, s={s_hidden.numel()}"
+        )
+
+    s = s_hidden.to(device=W1.device, dtype=W1.dtype).clamp(min=1e-5)  # [hidden]
+    inv_s = (1.0 / s).clamp(min=1e-5)
+
+    # fc1: row-wise scale
+    W1.mul_(inv_s.view(-1, 1))
+    if hasattr(fc1, "bias") and getattr(fc1, "bias") is not None:
+        assert isinstance(fc1.bias, torch.Tensor)
+        fc1.bias.mul_(inv_s)
+
+    # fc2: column-wise scale
+    W2.mul_(s.view(1, -1))
+
+
+# ────────────────────────────────────────────────────────────
+# Per-layer appliers (uniform protocol): return True if applied, else False
+# ────────────────────────────────────────────────────────────
+
+
+@torch.no_grad()
+def _apply_if_llama_decoder(
+    name: str,
+    module: torch.nn.Module,
+    activation_max: Dict[str, torch.Tensor],
+    alpha_to_apply: float,
+) -> bool:
+    """
+    Apply LLaMA decoder-layer smoothing (input-hook stats).
+    Returns True if this handler applied smoothing to `module`.
+    """
+    try:
+        from transformers.models.llama.modeling_llama import (  # type: ignore
+            LlamaDecoderLayer,
+        )
+    except Exception:
+        return False
+
+    if not isinstance(module, LlamaDecoderLayer):
+        return False
+
+    attn_ln = module.input_layernorm
+    qkv = [
+        module.self_attn.q_proj,
+        module.self_attn.k_proj,
+        module.self_attn.v_proj,
+    ]
+    # Input-hook stats for q_proj input
+    qkv_input_scales = activation_max[name + ".self_attn.q_proj"]
+    smooth_weights(attn_ln, qkv, qkv_input_scales, alpha_to_apply)
+
+    ffn_ln = module.post_attention_layernorm
+    fcs = [module.mlp.gate_proj, module.mlp.up_proj]
+    # Input-hook stats for gate_proj input
+    fcs_input_scales = activation_max[name + ".mlp.gate_proj"]
+    smooth_weights(ffn_ln, fcs, fcs_input_scales, alpha_to_apply)
+
+    return True
+
+
+@torch.no_grad()
+def _apply_if_fairseq_relu_bridge(
+    name: str,
+    module: torch.nn.Module,
+    activation_max: Dict[str, torch.Tensor],
+    alpha_to_apply: float,
+) -> bool:
+    """
+    Apply fairseq Transformer (Encoder/Decoder) ReLU-FFN bridge fusion
+    using input-hook stats at '{name}.fc1'. Returns True if applied.
+    """
+    try:
+        from fairseq.modules.transformer_layer import (
+            TransformerDecoderLayerBase,
+            TransformerEncoderLayerBase,
+        )  # type: ignore
+    except Exception:
+        return False
+
+    if not isinstance(
+        module, (TransformerEncoderLayerBase, TransformerDecoderLayerBase)
+    ):
+        return False
+
+    # Only when FFN activation is ReLU (positive homogeneity)
+    act_fn = getattr(module, "activation_fn", None)
+    is_relu = (act_fn is torch.nn.functional.relu) or getattr(
+        act_fn, "__name__", ""
+    ) == "relu"
+    if not is_relu:
+        return False
+
+    fc1_key = f"{name}.fc1"
+    amax2 = activation_max.get(fc1_key)
+    if amax2 is None:
+        return False
+
+    fc1 = getattr(module, "fc1", None)
+    fc2 = getattr(module, "fc2", None)
+    if fc1 is None or fc2 is None or not hasattr(fc2, "weight") or fc2.weight.ndim != 2:
+        return False
+
+    s_hidden = _compute_s_for_linear(fc2, amax2, alpha_to_apply)  # [hidden]
+    _fuse_relu_bridge_no_runtime_mul(fc1, fc2, s_hidden)
+    return True
+
+
+# Registry of appliers (order matters: try LLaMA first, then fairseq)
+_APPLIERS: List[
+    Callable[[str, torch.nn.Module, Dict[str, torch.Tensor], float], bool]
+] = [
+    _apply_if_llama_decoder,
+    _apply_if_fairseq_relu_bridge,
+]
+
+
+@torch.no_grad()
+def apply_smoothing(
+    model: torch.nn.Module,
+    activation_max: Dict[str, torch.Tensor],
+    alpha: float = 0.5,
+    custom_alpha_map: Optional[Dict[str, float]] = None,
+):
+    """
+    Applies SmoothQuant-style smoothing to the model's weights using activation maximum values.
+
+    Parameters
+    -----------
+    model
+        A torch module whose weights will be smoothed.
+    activation_max
+        The channel-wise maximum activation values for the model.
+    alpha
+        The default smoothing factor to apply across all modules.
+    custom_alpha_map
+        A dictionary mapping layer/module names to custom alpha values.
+        Layers specified in this dictionary will use the corresponding alpha
+        value instead of the default.
+    """
+    for name, module in model.named_modules():
+        alpha_to_apply = (
+            custom_alpha_map.get(name, alpha) if custom_alpha_map else alpha
+        )
+        if alpha_to_apply > 1.0:
+            raise RuntimeError(
+                f"Alpha value cannot exceed 1.0. Given alpha: {alpha_to_apply}"
+            )
+
+        # Try each applier until one succeeds.
+        for applier in _APPLIERS:
+            if applier(name, module, activation_max, alpha_to_apply):
+                break  # applied → stop trying others
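
A quick numerical check of the new fairseq ReLU-bridge helpers: scaling fc1's rows by 1/s and fc2's columns by s must leave the fc1 → ReLU → fc2 output unchanged (positive homogeneity of ReLU). The snippet only needs torch and the two helpers added above; the toy layer sizes are illustrative.

```python
import torch

from tico.quantization.algorithm.smoothquant.smooth_quant import (
    _compute_s_for_linear,
    _fuse_relu_bridge_no_runtime_mul,
)

fc1, fc2 = torch.nn.Linear(8, 16), torch.nn.Linear(16, 4)
x = torch.randn(3, 8)
ref = fc2(torch.relu(fc1(x)))                    # reference output before smoothing

amax = torch.relu(fc1(x)).abs().amax(dim=0)      # stand-in for input-hook stats at fc2
s = _compute_s_for_linear(fc2, amax, alpha=0.5)  # s = amax^a / w_col_max^(1-a)
_fuse_relu_bridge_no_runtime_mul(fc1, fc2, s)    # fold 1/s into fc1, s into fc2

assert torch.allclose(ref, fc2(torch.relu(fc1(x))), atol=1e-5)
```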

tico/quantization/config/base.py
@@ -0,0 +1,26 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+
+
+class BaseConfig(ABC):
+    """
+    Base configuration class for quantization.
+    """
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        pass

tico/quantization/config/fpi_gptq.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tico.quantization.config.gptq import GPTQConfig
+
+
+class FPIGPTQConfig(GPTQConfig):
+    """
+    Configuration for FPIGPTQ (Fixed Point Iteration).
+    """
+
+    def __init__(self, verbose: bool = False, show_progress: bool = True):
+        self.verbose = verbose
+        self.show_progress = show_progress
+
+    @property
+    def name(self) -> str:
+        return "fpi_gptq"

tico/quantization/config/gptq.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tico.quantization.config.base import BaseConfig
+
+
+class GPTQConfig(BaseConfig):
+    """
+    Configuration for GPTQ.
+    """
+
+    def __init__(self, verbose: bool = False, show_progress: bool = True):
+        self.verbose = verbose
+        self.show_progress = show_progress
+
+    @property
+    def name(self) -> str:
+        return "gptq"

tico/quantization/config/pt2e.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tico.quantization.config.base import BaseConfig
+
+
+class PT2EConfig(BaseConfig):
+    """
+    Configuration for pytorch 2.0 export quantization.
+    """
+
+    @property
+    def name(self) -> str:
+        return "pt2e"
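
The new config hierarchy in one place: every concrete config exposes a stable `name`, and the `@register_quantizer(...)` decorators seen earlier key quantizer dispatch off these config classes.

```python
from tico.quantization.config.fpi_gptq import FPIGPTQConfig
from tico.quantization.config.gptq import GPTQConfig
from tico.quantization.config.pt2e import PT2EConfig

for cfg in (GPTQConfig(verbose=True), FPIGPTQConfig(show_progress=False), PT2EConfig()):
    # Prints: GPTQConfig -> gptq, FPIGPTQConfig -> fpi_gptq, PT2EConfig -> pt2e
    print(type(cfg).__name__, "->", cfg.name)
```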

tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py}
@@ -15,14 +15,15 @@
 from dataclasses import dataclass, field
 from typing import Any, Dict, Mapping, Type
 
-from tico.experimental.quantization.ptq.dtypes import DType
-from tico.experimental.quantization.ptq.observers.base import ObserverBase
-from tico.experimental.quantization.ptq.observers.minmax import MinMaxObserver
-from tico.experimental.quantization.ptq.qscheme import QScheme
+from tico.quantization.config.base import BaseConfig
+from tico.quantization.wrapq.dtypes import DType
+from tico.quantization.wrapq.observers.base import ObserverBase
+from tico.quantization.wrapq.observers.minmax import MinMaxObserver
+from tico.quantization.wrapq.qscheme import QScheme
 
 
 @dataclass
-class QuantConfig:
+class PTQConfig(BaseConfig):
     """
     One object describes the quantization preferences for a single wrapper
     and its descendants.
@@ -54,9 +55,9 @@ class QuantConfig:
     Example
     -------
     ```python
-    from ptq.observers import PercentileObserver
+    from wrapq.observers import PercentileObserver
 
-    cfg = QuantConfig(
+    cfg = PTQConfig(
         default_dtype = DType.uint(8),
         default_qscheme = QScheme.PER_TENSOR_SYMM,  # <- global scheme
         default_observer = PercentileObserver,  # <- global algorithm
@@ -74,6 +75,12 @@ class QuantConfig:
     default_observer: Type[ObserverBase] = MinMaxObserver
     default_qscheme: QScheme = QScheme.PER_TENSOR_ASYMM
     overrides: Mapping[str, Mapping[str, Any]] = field(default_factory=dict)
+    # If True, any module that cannot be wrapped will raise.
+    strict_wrap: bool = True
+
+    @property
+    def name(self) -> str:
+        return "ptq"
 
     def get_kwargs(self, obs_name: str) -> Dict[str, Any]:
         """
@@ -87,7 +94,7 @@
         """
         return dict(self.overrides.get(obs_name, {}))
 
-    def child(self, scope: str) -> "QuantConfig":
+    def child(self, scope: str) -> "PTQConfig":
         """
         Produce a *view* for a child wrapper.
 
@@ -100,12 +107,13 @@
         Other scopes remain invisible to the child.
         """
         sub_overrides = self.overrides.get(scope, {})
-        return QuantConfig(
+        return PTQConfig(
            self.default_dtype,
            self.default_observer,
            default_qscheme=self.default_qscheme,
            overrides=sub_overrides,
+           strict_wrap=self.strict_wrap,
        )
 
     def __repr__(self):
-        return f"QuantConfig(default_dtype={self.default_dtype}, default_observer={self.default_observer}, default_qscheme={self.default_qscheme}, overrides={dict(self.overrides)})"
+        return f"PTQConfig(default_dtype={self.default_dtype}, default_observer={self.default_observer}, default_qscheme={self.default_qscheme}, overrides={dict(self.overrides)}, strict_wrap={self.strict_wrap})"