tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py
@@ -16,20 +16,37 @@ from typing import Any, Dict, Optional
 
 import torch
 
-from tico.experimental.quantization.algorithm.smoothquant.observer import (
-    ChannelwiseMaxActsObserver,
-)
+from tico.quantization.algorithm.smoothquant.observer import ChannelwiseMaxActsObserver
 
-from tico.experimental.quantization.algorithm.smoothquant.smooth_quant import (
-    apply_smoothing,
-)
-from tico.experimental.quantization.config import SmoothQuantConfig
-from tico.experimental.quantization.quantizer import BaseQuantizer
+from tico.quantization.algorithm.smoothquant.smooth_quant import apply_smoothing
+from tico.quantization.config.smoothquant import SmoothQuantConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import register_quantizer
 
 
+@register_quantizer(SmoothQuantConfig)
 class SmoothQuantQuantizer(BaseQuantizer):
     """
     Quantizer for applying the SmoothQuant algorithm
+
+    Q) Why allow choosing between input and output activations?
+
+    SmoothQuant relies on channel-wise activation statistics to balance
+    weights and activations. In practice, there are two natural sources:
+
+    - "input": captures the tensor right before a Linear layer
+      (forward-pre-hook). This matches the original SmoothQuant paper
+      and focuses on scaling the raw hidden state.
+
+    - "output": captures the tensor right after a Linear layer
+      (forward-hook). This can better reflect post-weight dynamics,
+      especially when subsequent operations (bias, activation functions)
+      dominate the dynamic range.
+
+    Allowing both options provides flexibility: depending on model
+    architecture and calibration data, one may yield lower error than
+    the other. The default remains "input" for compatibility, but "output"
+    can be selected to empirically reduce error or runtime overhead.
     """
 
     def __init__(self, config: SmoothQuantConfig):
@@ -37,6 +54,7 @@ class SmoothQuantQuantizer(BaseQuantizer):
 
         self.alpha = config.alpha
         self.custom_alpha_map = config.custom_alpha_map
+        self.acts_from = config.acts_from  # "input" (default) or "output"
         self.observer: Optional[ChannelwiseMaxActsObserver] = None
 
     @torch.no_grad()
@@ -55,7 +73,8 @@ class SmoothQuantQuantizer(BaseQuantizer):
         Returns:
             The model prepared for SmoothQuant quantization.
         """
-        self.observer = ChannelwiseMaxActsObserver(model)
+        # Attach hooks according to `config.acts_from`
+        self.observer = ChannelwiseMaxActsObserver(model, acts_from=self.acts_from)
         self.observer.attach()
 
         return model
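
For orientation, a minimal sketch of how the new `acts_from` knob is exercised. Only `SmoothQuantConfig` and `SmoothQuantQuantizer` appear in this diff; the layer name used in `custom_alpha_map` and the calibration/convert steps that would follow are illustrative assumptions, not part of the hunk.

```python
from tico.quantization.algorithm.smoothquant.quantizer import SmoothQuantQuantizer
from tico.quantization.config.smoothquant import SmoothQuantConfig

cfg = SmoothQuantConfig(
    alpha=0.5,
    custom_alpha_map={"model.layers.0": 0.75},  # hypothetical per-layer alpha override
    acts_from="output",  # collect stats via forward-hooks instead of forward-pre-hooks
)
quantizer = SmoothQuantQuantizer(cfg)  # observer hooks follow cfg.acts_from
# ... run calibration data through the prepared model, then convert (not shown here)
```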
tico/quantization/algorithm/smoothquant/smooth_quant.py
@@ -0,0 +1,327 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Dict, List, Optional
+
+import torch
+
+
+@torch.no_grad()
+def smooth_weights(
+    front_module: torch.nn.Module,
+    back_modules: torch.nn.Module | List[torch.nn.Module],
+    activation_max: torch.Tensor,
+    alpha: float,
+):
+    """
+    Applies SmoothQuant-style smoothing to the weights and biases of two
+    connected modules using activation maximum values.
+
+    NOTE All modules **MUST** have `weight` and optionally `bias` attributes.
+
+    Parameters
+    -----------
+    front_module
+        The front module whose weights and biases will be adjusted.
+    back_modules
+        A list of back modules whose weights and biases will be adjusted.
+    activation_max
+        A tensor of channel-wise maximum activation values for the front module.
+    alpha
+        The smoothing factor that determines the scaling for weight adjustments.
+
+    Raises
+    -------
+    AttributeError
+        If `front_module` or any module in `back_modules` does not have a `weight` attribute.
+    ValueError
+        If the shape of tensors in `activation_max` does not match the number of channels
+        in `front_module`'s weight.
+    NotImplementedError
+        If `front_module` or any module in `back_modules` is of an unsupported type.
+    """
+    from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
+    if not isinstance(back_modules, list):
+        back_modules = [back_modules]
+
+    # Check attributes
+    if not hasattr(front_module, "weight"):
+        raise AttributeError(
+            f"The front module '{type(front_module).__name__}' does not have a 'weight' attribute."
+        )
+    for back_m in back_modules:
+        if not hasattr(back_m, "weight"):
+            raise AttributeError(
+                f"The back module '{type(back_m).__name__}' does not have a 'weight' attribute."
+            )
+    # Check shapes
+    if isinstance(front_module, LlamaRMSNorm):
+        front_numel = front_module.weight.numel()
+    else:
+        raise NotImplementedError(
+            f"Unsupported module type: {type(front_module).__name__}"
+        )
+    for back_m in back_modules:
+        if isinstance(back_m, torch.nn.Linear):
+            back_numel = back_m.in_features
+        else:
+            raise NotImplementedError(
+                f"Unsupported module type: {type(back_m).__name__}"
+            )
+
+    if front_numel != back_numel or back_numel != activation_max.numel():
+        raise ValueError(
+            f"Shape mismatch: front_numel({front_numel}), back_numel({back_numel}), activation_max_numel({activation_max.numel()})"
+        )
+
+    # Compute scales
+    device, dtype = back_modules[0].weight.device, back_modules[0].weight.dtype
+    activation_max = activation_max.to(device=device, dtype=dtype)  # type: ignore[arg-type]
+    weight_scales = torch.cat(
+        [back_m.weight.abs().max(dim=0, keepdim=True)[0] for back_m in back_modules],  # type: ignore[operator]
+        dim=0,
+    )
+    weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
+    scales = (
+        (activation_max.pow(alpha) / weight_scales.pow(1 - alpha))
+        .clamp(min=1e-5)
+        .to(device)  # type: ignore[arg-type]
+        .to(dtype)  # type: ignore[arg-type]
+    )
+
+    # Smooth
+    front_module.weight.div_(scales)
+    if hasattr(front_module, "bias"):
+        front_module.bias.div_(scales)
+
+    for back_m in back_modules:
+        back_m.weight.mul_(scales.view(1, -1))  # type: ignore[operator]
+
+
+# TODO Split the files per model
+# ────────────────────────────────────────────────────────────
+# fairseq ReLU bridge (input-hook stats) helpers
+# ────────────────────────────────────────────────────────────
+
+
+@torch.no_grad()
+def _compute_s_for_linear(
+    linear_like: torch.nn.Module,  # 2D weight [out, in]
+    activation_max: torch.Tensor,  # shape [in]
+    alpha: float,
+) -> torch.Tensor:
+    """
+    s = (amax^alpha / w_col_max^(1-alpha))
+    - amax: channel-wise max of the input to this module
+    - w_col_max: max(|W|) per input column
+    """
+    if not hasattr(linear_like, "weight"):
+        raise RuntimeError(f"{type(linear_like).__name__} has no 'weight' attribute.")
+    W = linear_like.weight  # [out, in]
+    assert isinstance(W, torch.Tensor)
+    if W.ndim != 2:
+        raise RuntimeError(
+            f"Expected 2D weight, got {W.ndim}D for {type(linear_like).__name__}"
+        )
+
+    device, dtype = W.device, W.dtype
+    amax = activation_max.to(device=device, dtype=dtype)
+
+    if amax.numel() != W.shape[1]:
+        raise ValueError(
+            f"activation_max numel({amax.numel()}) != in_features({W.shape[1]})"
+        )
+
+    w_col_max = W.abs().max(dim=0)[0].clamp(min=1e-5)  # [in]
+    s = (amax.pow(alpha) / w_col_max.pow(1.0 - alpha)).clamp(min=1e-5)  # [in]
+    return s
+
+
+@torch.no_grad()
+def _fuse_relu_bridge_no_runtime_mul(
+    fc1: torch.nn.Module,
+    fc2: torch.nn.Module,
+    s_hidden: torch.Tensor,
+):
+    """
+    Fuse scaling across fc1 → ReLU → fc2 without runtime multiplies:
+    - fc1 rows *= 1/s, (fc1.bias *= 1/s)
+    - fc2 cols *= s
+    Assumes middle activation is ReLU (positive homogeneous).
+    """
+    if not hasattr(fc1, "weight") or not hasattr(fc2, "weight"):
+        raise RuntimeError("fc1/fc2 must have 'weight' attributes.")
+
+    W1, W2 = fc1.weight, fc2.weight
+    assert isinstance(W1, torch.Tensor) and isinstance(W2, torch.Tensor)
+    if W1.ndim != 2 or W2.ndim != 2:
+        raise RuntimeError("fc1/fc2 weights must be 2D.")
+
+    hidden = W1.shape[0]
+    if W2.shape[1] != hidden or s_hidden.numel() != hidden:
+        raise ValueError(
+            f"Dimension mismatch: hidden={hidden}, W2.in={W2.shape[1]}, s={s_hidden.numel()}"
+        )
+
+    s = s_hidden.to(device=W1.device, dtype=W1.dtype).clamp(min=1e-5)  # [hidden]
+    inv_s = (1.0 / s).clamp(min=1e-5)
+
+    # fc1: row-wise scale
+    W1.mul_(inv_s.view(-1, 1))
+    if hasattr(fc1, "bias") and getattr(fc1, "bias") is not None:
+        assert isinstance(fc1.bias, torch.Tensor)
+        fc1.bias.mul_(inv_s)
+
+    # fc2: column-wise scale
+    W2.mul_(s.view(1, -1))
+
+
+# ────────────────────────────────────────────────────────────
+# Per-layer appliers (uniform protocol): return True if applied, else False
+# ────────────────────────────────────────────────────────────
+
+
+@torch.no_grad()
+def _apply_if_llama_decoder(
+    name: str,
+    module: torch.nn.Module,
+    activation_max: Dict[str, torch.Tensor],
+    alpha_to_apply: float,
+) -> bool:
+    """
+    Apply LLaMA decoder-layer smoothing (input-hook stats).
+    Returns True if this handler applied smoothing to `module`.
+    """
+    try:
+        from transformers.models.llama.modeling_llama import (  # type: ignore
+            LlamaDecoderLayer,
+        )
+    except Exception:
+        return False
+
+    if not isinstance(module, LlamaDecoderLayer):
+        return False
+
+    attn_ln = module.input_layernorm
+    qkv = [
+        module.self_attn.q_proj,
+        module.self_attn.k_proj,
+        module.self_attn.v_proj,
+    ]
+    # Input-hook stats for q_proj input
+    qkv_input_scales = activation_max[name + ".self_attn.q_proj"]
+    smooth_weights(attn_ln, qkv, qkv_input_scales, alpha_to_apply)
+
+    ffn_ln = module.post_attention_layernorm
+    fcs = [module.mlp.gate_proj, module.mlp.up_proj]
+    # Input-hook stats for gate_proj input
+    fcs_input_scales = activation_max[name + ".mlp.gate_proj"]
+    smooth_weights(ffn_ln, fcs, fcs_input_scales, alpha_to_apply)
+
+    return True
+
+
+@torch.no_grad()
+def _apply_if_fairseq_relu_bridge(
+    name: str,
+    module: torch.nn.Module,
+    activation_max: Dict[str, torch.Tensor],
+    alpha_to_apply: float,
+) -> bool:
+    """
+    Apply fairseq Transformer (Encoder/Decoder) ReLU-FFN bridge fusion
+    using input-hook stats at '{name}.fc1'. Returns True if applied.
+    """
+    try:
+        from fairseq.modules.transformer_layer import (
+            TransformerDecoderLayerBase,
+            TransformerEncoderLayerBase,
+        )  # type: ignore
+    except Exception:
+        return False
+
+    if not isinstance(
+        module, (TransformerEncoderLayerBase, TransformerDecoderLayerBase)
+    ):
+        return False
+
+    # Only when FFN activation is ReLU (positive homogeneity)
+    act_fn = getattr(module, "activation_fn", None)
+    is_relu = (act_fn is torch.nn.functional.relu) or getattr(
+        act_fn, "__name__", ""
+    ) == "relu"
+    if not is_relu:
+        return False
+
+    fc1_key = f"{name}.fc1"
+    amax2 = activation_max.get(fc1_key)
+    if amax2 is None:
+        return False
+
+    fc1 = getattr(module, "fc1", None)
+    fc2 = getattr(module, "fc2", None)
+    if fc1 is None or fc2 is None or not hasattr(fc2, "weight") or fc2.weight.ndim != 2:
+        return False
+
+    s_hidden = _compute_s_for_linear(fc2, amax2, alpha_to_apply)  # [hidden]
+    _fuse_relu_bridge_no_runtime_mul(fc1, fc2, s_hidden)
+    return True
+
+
+# Registry of appliers (order matters: try LLaMA first, then fairseq)
+_APPLIERS: List[
+    Callable[[str, torch.nn.Module, Dict[str, torch.Tensor], float], bool]
+] = [
+    _apply_if_llama_decoder,
+    _apply_if_fairseq_relu_bridge,
+]
+
+
+@torch.no_grad()
+def apply_smoothing(
+    model: torch.nn.Module,
+    activation_max: Dict[str, torch.Tensor],
+    alpha: float = 0.5,
+    custom_alpha_map: Optional[Dict[str, float]] = None,
+):
+    """
+    Applies SmoothQuant-style smoothing to the model's weights using activation maximum values.
+
+    Parameters
+    -----------
+    model
+        A torch module whose weights will be smoothed.
+    activation_max
+        The channel-wise maximum activation values for the model.
+    alpha
+        The default smoothing factor to apply across all modules.
+    custom_alpha_map
+        A dictionary mapping layer/module names to custom alpha values.
+        Layers specified in this dictionary will use the corresponding alpha
+        value instead of the default.
+    """
+    for name, module in model.named_modules():
+        alpha_to_apply = (
+            custom_alpha_map.get(name, alpha) if custom_alpha_map else alpha
+        )
+        if alpha_to_apply > 1.0:
+            raise RuntimeError(
+                f"Alpha value cannot exceed 1.0. Given alpha: {alpha_to_apply}"
+            )
+
+        # Try each applier until one succeeds.
+        for applier in _APPLIERS:
+            if applier(name, module, activation_max, alpha_to_apply):
+                break  # applied → stop trying others
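
The fairseq bridge above leans on ReLU's positive homogeneity: for s > 0, relu(z / s) * s == relu(z), so folding 1/s into fc1 (rows and bias) and s into fc2 (columns) rebalances the hidden activation range without changing the FFN output and without any runtime multiplies. A self-contained sketch in plain torch (deliberately not importing the helpers above) that checks this identity numerically:

```python
import torch

torch.manual_seed(0)
d_in, hidden, d_out = 4, 8, 4
fc1 = torch.nn.Linear(d_in, hidden)
fc2 = torch.nn.Linear(hidden, d_out)
x = torch.randn(2, d_in)

ref = fc2(torch.relu(fc1(x)))  # original fc1 -> ReLU -> fc2 output

# Any positive per-hidden-channel scale works for the identity; SmoothQuant
# would derive s = amax**alpha / w_col_max**(1 - alpha) from calibration stats.
s = torch.rand(hidden) + 0.5
with torch.no_grad():
    fc1.weight.mul_((1.0 / s).view(-1, 1))  # fc1 rows *= 1/s
    fc1.bias.mul_(1.0 / s)                  # fc1 bias *= 1/s
    fc2.weight.mul_(s.view(1, -1))          # fc2 cols *= s

fused = fc2(torch.relu(fc1(x)))
print(torch.allclose(ref, fused, atol=1e-5))  # True: output preserved
```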
tico/quantization/config/base.py
@@ -0,0 +1,26 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+
+
+class BaseConfig(ABC):
+    """
+    Base configuration class for quantization.
+    """
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        pass
tico/quantization/config/gptq.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tico.quantization.config.base import BaseConfig
+
+
+class GPTQConfig(BaseConfig):
+    """
+    Configuration for GPTQ.
+    """
+
+    def __init__(self, verbose: bool = False, show_progress: bool = True):
+        self.verbose = verbose
+        self.show_progress = show_progress
+
+    @property
+    def name(self) -> str:
+        return "gptq"
tico/quantization/config/pt2e.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tico.quantization.config.base import BaseConfig
+
+
+class PT2EConfig(BaseConfig):
+    """
+    Configuration for pytorch 2.0 export quantization.
+    """
+
+    @property
+    def name(self) -> str:
+        return "pt2e"
tico/quantization/config/ptq.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, Mapping, Type
+
+from tico.quantization.config.base import BaseConfig
+from tico.quantization.wrapq.dtypes import DType
+from tico.quantization.wrapq.observers.base import ObserverBase
+from tico.quantization.wrapq.observers.minmax import MinMaxObserver
+from tico.quantization.wrapq.qscheme import QScheme
+
+
+@dataclass
+class PTQConfig(BaseConfig):
+    """
+    One object describes the quantization preferences for a single wrapper
+    and its descendants.
+
+    Parameters
+    ----------
+    default_dtype : DType
+        Fallback dtype for every observer that DOES NOT receive an explicit
+        override.
+    default_observer : Type[ObserverBase], optional
+        Observer class to instantiate when the caller (or an override) does
+        not provide an `observer` key.
+    default_qscheme : QScheme
+        Fallback quantization scheme (per-tensor / per-channel,
+        asymmetric / symmetric) for observers that DO NOT receive an explicit
+        override.
+    overrides : Mapping[str, Mapping[str, Any]]
+        Two-level mapping of scopes → observer-kwargs.
+
+        • SCOPE can be either
+            - the attribute name of a child wrapper
+              (e.g. "gate_proj" or "up_proj"), or
+            - an observer logical name inside this wrapper
+              (e.g. "mul", "act_in").
+
+        • "Observer-kwargs" is forwarded verbatim to the observer constructor
+          (`dtype`, `qscheme`, `channel_axis`, `observer`, …).
+
+    Example
+    -------
+    ```python
+    from wrapq.observers import PercentileObserver
+
+    cfg = PTQConfig(
+        default_dtype = DType.uint(8),
+        default_qscheme = QScheme.PER_TENSOR_SYMM,   # <- global scheme
+        default_observer = PercentileObserver,       # <- global algorithm
+        overrides={
+            # local override: input observer now MinMax & 4-bit, per-channel asymmetric
+            "act_in": {"observer": MinMaxObserver,
+                       "dtype": DType.uint(4),
+                       "qscheme": QScheme.PER_CHANNEL_ASYMM},
+        },
+    )
+    ```
+    """
+
+    default_dtype: DType = DType.uint(8)
+    default_observer: Type[ObserverBase] = MinMaxObserver
+    default_qscheme: QScheme = QScheme.PER_TENSOR_ASYMM
+    overrides: Mapping[str, Mapping[str, Any]] = field(default_factory=dict)
+    # If True, any module that cannot be wrapped will raise.
+    strict_wrap: bool = True
+
+    @property
+    def name(self) -> str:
+        return "ptq"
+
+    def get_kwargs(self, obs_name: str) -> Dict[str, Any]:
+        """
+        Return user-specified kwargs for *obs_name* inside **this** wrapper.
+
+        NOTE:
+        Do NOT inject a dtype/qscheme here. `_make_obs()` resolves precedence:
+          1) user override (kw_cfg["dtype" | "qscheme"])
+          2) wrapper's default passed to `_make_obs(..., dtype=..., qscheme=...)`
+          3) `self.default_dtype` / `self.default_qscheme`
+        """
+        return dict(self.overrides.get(obs_name, {}))
+
+    def child(self, scope: str) -> "PTQConfig":
+        """
+        Produce a *view* for a child wrapper.
+
+        The child inherits:
+          • same `default_dtype`
+          • same `default_observer`
+          • same `default_qscheme`
+          • overrides under `self.overrides.get(scope, {})`
+
+        Other scopes remain invisible to the child.
+        """
+        sub_overrides = self.overrides.get(scope, {})
+        return PTQConfig(
+            self.default_dtype,
+            self.default_observer,
+            default_qscheme=self.default_qscheme,
+            overrides=sub_overrides,
+            strict_wrap=self.strict_wrap,
+        )
+
+    def __repr__(self):
+        return f"PTQConfig(default_dtype={self.default_dtype}, default_observer={self.default_observer}, default_qscheme={self.default_qscheme}, overrides={dict(self.overrides)}, strict_wrap={self.strict_wrap})"
tico/{experimental/quantization/config.py → quantization/config/smoothquant.py}
@@ -12,42 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from abc import ABC, abstractmethod
-from typing import Dict, Optional
+from typing import Dict, Literal, Optional
 
-
-class BaseConfig(ABC):
-    """
-    Base configuration class for quantization.
-    """
-
-    @property
-    @abstractmethod
-    def name(self) -> str:
-        pass
-
-
-class PT2EConfig(BaseConfig):
-    """
-    Configuration for pytorch 2.0 export quantization.
-    """
-
-    @property
-    def name(self) -> str:
-        return "pt2e"
-
-
-class GPTQConfig(BaseConfig):
-    """
-    Configuration for GPTQ.
-    """
-
-    def __init__(self, verbose: bool = False):
-        self.verbose = verbose
-
-    @property
-    def name(self) -> str:
-        return "gptq"
+from tico.quantization.config.base import BaseConfig
 
 
 class SmoothQuantConfig(BaseConfig):
@@ -59,10 +26,16 @@ class SmoothQuantConfig(BaseConfig):
         self,
         alpha: float = 0.5,
         custom_alpha_map: Optional[Dict[str, float]] = None,
+        acts_from: Literal["input", "output"] = "input",
     ):
         self.alpha = alpha
         self.custom_alpha_map = custom_alpha_map
+        # Where to collect activation statistics from:
+        #  - "input": use forward-pre-hook (Tensor before the Linear op)
+        #  - "output": use forward-hook (Tensor after the Linear op)
+        # Default is "input".
+        self.acts_from = acts_from
 
     @property
     def name(self) -> str:
-        return "smooth_quant"
+        return "smoothquant"