tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
tico/quantization/wrapq/observers/affine_base.py (new file)
@@ -0,0 +1,128 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import Optional, Tuple
+
+ import torch
+
+ from tico.quantization.wrapq.dtypes import DType, UINT8
+ from tico.quantization.wrapq.observers.base import ObserverBase
+ from tico.quantization.wrapq.qscheme import QScheme
+
+
+ class AffineObserverBase(ObserverBase):
+     """Base for affine observers (min/max → scale/zp)."""
+
+     def __init__(
+         self,
+         *,
+         name: str,
+         dtype: DType = UINT8,
+         qscheme: QScheme = QScheme.PER_TENSOR_ASYMM,
+         channel_axis: Optional[int] = None,
+     ):
+         super().__init__(
+             name=name, dtype=dtype, qscheme=qscheme, channel_axis=channel_axis
+         )
+
+     def reset(self) -> None:
+         """
+         Reset running min/max and drop cached qparams.
+         """
+         self.min_val: torch.Tensor = torch.tensor(math.inf)
+         self.max_val: torch.Tensor = torch.tensor(-math.inf)
+         if hasattr(self, "_cached_scale"):
+             del self._cached_scale
+         if hasattr(self, "_cached_zp"):
+             del self._cached_zp
+
+     def load_qparams(self, scale: torch.Tensor, zp: torch.Tensor, *, lock: bool = True):
+         """
+         Inject externally computed qparams and optionally lock the observer.
+
+         When locked, subsequent `collect()` calls are ignored.
+         """
+         self._cached_scale = scale.detach()
+         self._cached_zp = zp.to(torch.int)
+         if lock:
+             self.enabled = False
+
+     @property
+     def has_qparams(self) -> bool:
+         return hasattr(self, "_cached_scale")
+
+     def compute_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
+         qmin, qmax = self.dtype.qmin, self.dtype.qmax
+         rng = self.max_val - self.min_val
+         eps = 1e-12
+
+         if self.qscheme.is_symmetric():
+             max_abs = torch.maximum(self.max_val.abs(), self.min_val.abs())
+             scale = torch.clamp(max_abs, min=eps) / qmax
+             zp = torch.zeros_like(scale, dtype=torch.int)
+             self._cached_scale, self._cached_zp = scale, zp
+             return scale, zp
+
+         if self.channel_axis is None:
+             if torch.all(rng.abs() < 1e-8):
+                 C = self.min_val
+                 if torch.allclose(C, torch.zeros_like(C)):
+                     scale = torch.ones_like(C)
+                     zp = torch.zeros_like(C, dtype=torch.int)
+                 elif (C > 0).all():
+                     scale = torch.clamp(C, min=eps)
+                     zp = torch.zeros_like(C, dtype=torch.int)
+                 else:
+                     scale = torch.clamp(C.abs(), min=eps)
+                     zp = torch.full_like(C, qmax, dtype=torch.int)
+             else:
+                 scale = torch.clamp(rng, min=eps) / (qmax - qmin)
+                 zp = (
+                     torch.round(qmin - self.min_val / scale)
+                     .clamp(qmin, qmax)
+                     .to(torch.int)
+                 )
+         else:
+             scale = torch.clamp(rng, min=eps) / (qmax - qmin)
+             zp = (
+                 torch.round(qmin - self.min_val / scale).clamp(qmin, qmax).to(torch.int)
+             )
+
+         self._cached_scale, self._cached_zp = scale, zp
+         return scale, zp
+
+     def fake_quant(self, x: torch.Tensor) -> torch.Tensor:
+         if not self.has_qparams:
+             raise RuntimeError(
+                 "Call compute_qparams()/freeze_qparams() or load_qparams() first."
+             )
+         scale, zp = self._cached_scale, self._cached_zp
+         if self.channel_axis is None:
+             return torch.fake_quantize_per_tensor_affine(
+                 x,
+                 scale=scale,
+                 zero_point=zp,
+                 quant_min=self.dtype.qmin,
+                 quant_max=self.dtype.qmax,
+             )
+         else:
+             return torch.fake_quantize_per_channel_affine(
+                 x,
+                 scale=scale,
+                 zero_point=zp,
+                 axis=self.channel_axis,
+                 quant_min=self.dtype.qmin,
+                 quant_max=self.dtype.qmax,
+             )
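For reference, a minimal sketch of what the per-tensor asymmetric branch of compute_qparams() above produces for a concrete range. The values below are illustrative only and assume UINT8 semantics (qmin=0, qmax=255); the snippet reproduces the formula by hand rather than calling the observer class.

import torch

# Hedged sketch: redo the per-tensor asymmetric qparam math shown above.
min_val, max_val = torch.tensor(-0.5), torch.tensor(1.5)
qmin, qmax, eps = 0, 255, 1e-12
scale = torch.clamp(max_val - min_val, min=eps) / (qmax - qmin)            # ~0.00784
zp = torch.round(qmin - min_val / scale).clamp(qmin, qmax).to(torch.int)   # 64
x = torch.tensor([-0.5, 0.0, 1.5])
x_fq = torch.fake_quantize_per_tensor_affine(
    x, scale=float(scale), zero_point=int(zp), quant_min=qmin, quant_max=qmax
)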
tico/quantization/wrapq/observers/base.py (new file)
@@ -0,0 +1,98 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+ from typing import Optional, Tuple
+
+ import torch
+
+ from tico.quantization.wrapq.dtypes import DType, UINT8
+ from tico.quantization.wrapq.qscheme import QScheme
+
+
+ class ObserverBase(ABC):
+     """
+     Minimal abstract base for all observers/quantizers.
+
+     Subclasses must implement:
+       - reset()
+       - collect(x)
+       - fake_quant(x)
+       - compute_qparams(): optional in practice for some observers (e.g., MX),
+         but still part of the interface; those can return None.
+     """
+
+     def __init__(
+         self,
+         *,
+         name: str,
+         dtype: DType = UINT8,
+         qscheme: QScheme = QScheme.PER_TENSOR_ASYMM,
+         channel_axis: Optional[int] = None,  # None → per-tensor
+     ):
+         self.name = name
+         self.dtype = dtype
+         self.qscheme = qscheme
+         self.channel_axis = channel_axis if qscheme.is_per_channel() else None
+         self.enabled = True
+         self.reset()
+
+     @abstractmethod
+     def reset(self) -> None:
+         """Clear any running statistics or cached params."""
+         raise NotImplementedError
+
+     def collect(self, x: torch.Tensor) -> None:
+         """
+         Update running statistics with a new batch of data.
+
+         This base implementation guards on `enabled` and then calls `_update_stats(x)`.
+         Subclasses should implement `_update_stats(x)` instead of overriding `collect`.
+         """
+         if not self.enabled:
+             return
+         self._update_stats(x)
+
+     @abstractmethod
+     def _update_stats(self, x: torch.Tensor) -> None:
+         """
+         Update running statistics (min/max, hist, mse buffers, ...).
+
+         Must be implemented by subclasses (e.g., MinMax, EMA, Histogram, MSE).
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     def fake_quant(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Apply the observer's quantization.
+         Implementations may or may not rely on qparams.
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     def compute_qparams(self) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
+         """
+         Compute and (if applicable) cache quantization params.
+         Affine observers typically return (scale, zero_point).
+         Observers that do not use qparams (e.g., MX) may return None.
+         """
+         raise NotImplementedError
+
+     # String repr helps debugging
+     def __repr__(self) -> str:
+         return (
+             f"{self.__class__.__name__}(name={self.name}, dtype={str(self.dtype)}, "
+             f"qscheme={str(self.qscheme)}, channel_axis={self.channel_axis}, enabled={self.enabled})"
+         )
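To illustrate the subclass contract above, here is a hedged sketch of a toy observer. AbsMaxClampObserver is hypothetical and not part of this package; the real implementations (MinMaxObserver, EMAObserver, MXObserver) appear later in this diff.

import torch

from tico.quantization.wrapq.observers.base import ObserverBase


class AbsMaxClampObserver(ObserverBase):
    """Toy observer: tracks the absolute max and clamps to it; no affine qparams."""

    def reset(self) -> None:
        self.abs_max = torch.tensor(0.0)

    def _update_stats(self, x: torch.Tensor) -> None:
        self.abs_max = torch.maximum(self.abs_max, x.abs().max())

    def compute_qparams(self):
        return None  # like the MX observer, no (scale, zero_point) pair

    def fake_quant(self, x: torch.Tensor) -> torch.Tensor:
        return x.clamp(-self.abs_max, self.abs_max)


obs = AbsMaxClampObserver(name="act")   # `name` is required by ObserverBase
obs.collect(torch.randn(4, 8))          # routed through collect() -> _update_stats()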
tico/quantization/wrapq/observers/ema.py (new file)
@@ -0,0 +1,62 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+ from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+ from tico.quantization.wrapq.utils.reduce_utils import channelwise_minmax
+
+
+ class EMAObserver(AffineObserverBase):
+     """
+     Exponential-Moving-Average min/max tracker.
+
+     Why?
+     -----
+     • Smoother than raw MinMax (reduces outlier shock).
+     • Much cheaper than histogram/MSE observers.
+
+     The update rule follows the common "momentum" form:
+
+         ema = momentum * ema + (1 - momentum) * new_value
+
+     With momentum → 0: FAST adaptation, momentum → 1: SLOW adaptation.
+     """
+
+     def __init__(
+         self,
+         *,
+         momentum: float = 0.9,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         assert 0.0 < momentum < 1.0, "momentum must be in (0, 1)"
+         self.momentum = momentum
+
+     @torch.no_grad()
+     def _update_stats(self, x: torch.Tensor):
+         if self.channel_axis is None:
+             curr_min, curr_max = x.min(), x.max()
+         else:
+             curr_min, curr_max = channelwise_minmax(x, self.channel_axis)
+
+         if (
+             torch.isinf(self.min_val).any() and torch.isinf(self.max_val).any()
+         ):  # first batch → hard init
+             self.min_val, self.max_val = curr_min, curr_max
+             return
+
+         m = self.momentum
+         self.min_val = m * self.min_val + (1 - m) * curr_min
+         self.max_val = m * self.max_val + (1 - m) * curr_max
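A quick numeric illustration of the momentum rule documented above (plain arithmetic, not code from the package): with momentum = 0.9, each post-init batch moves the running bound 10% of the way toward the new batch statistic.

momentum = 0.9
running_max = 1.0   # value hard-initialized from the first batch
batch_max = 2.0     # maximum observed in the next batch
running_max = momentum * running_max + (1 - momentum) * batch_max
print(running_max)  # ~1.1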
tico/quantization/wrapq/observers/identity.py (new file)
@@ -0,0 +1,74 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ IdentityObserver: a "no-op" observer for FP-only modules.
+
+ Motivation
+ ----------
+ Some layers should stay in full precision even when the rest of the model
+ is quantized. Attaching an `IdentityObserver` satisfies the wrapper API
+ (`_update_stats()`, `compute_qparams()`, `fake_quant()`) without actually
+ performing any statistics gathering or fake-quantization.
+ """
+ import torch
+
+ from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+
+
+ class IdentityObserver(AffineObserverBase):
+     """
+     Passthrough observer that NEVER alters the tensor.
+
+     • `_update_stats()` → does nothing
+     • `compute_qparams()` → returns (1.0, 0) "dummy" q-params
+     • `fake_quant()` → returns `x` unchanged
+     """
+
+     def __init__(self, **kwargs):
+         # Call parent so the usual fields (`dtype`, `qscheme`, …) exist,
+         # but immediately disable any stateful behaviour.
+         super().__init__(**kwargs)
+
+         # Deactivate statistics collection permanently.
+         self.enabled = False
+
+         # Pre-cache sentinel q-params so wrapper code that blindly
+         # accesses them won't crash.
+         self._cached_scale = torch.tensor(1.0)
+         self._cached_zp = torch.tensor(0, dtype=torch.int)
+
+     def reset(self) -> None:  # (simple override – nothing to do)
+         """No internal state to reset."""
+         pass
+
+     def _update_stats(self, x: torch.Tensor) -> None:
+         """Skip statistic collection entirely."""
+         return
+
+     def compute_qparams(self):
+         """
+         Return the pre-cached (scale, zero_point) tuple.
+
+         Keeping the signature identical to other observers allows uniform
+         lifecycle management in wrapper code.
+         """
+         return self._cached_scale, self._cached_zp
+
+     def fake_quant(self, x: torch.Tensor):
+         """Identity mapping — leaves `x` in FP."""
+         return x
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}()"
tico/quantization/wrapq/observers/minmax.py (new file)
@@ -0,0 +1,39 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+ from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+ from tico.quantization.wrapq.utils.reduce_utils import channelwise_minmax
+
+
+ class MinMaxObserver(AffineObserverBase):
+     """Plain min/max range tracker."""
+
+     @torch.no_grad()
+     def _update_stats(self, x: torch.Tensor) -> None:
+         """
+         Update running min/max with the incoming batch.
+
+         Per-tensor: use global min/max.
+         Per-channel: reduce all axes except the channel axis.
+         """
+         if self.channel_axis is None:
+             curr_min, curr_max = x.min(), x.max()
+         else:
+             curr_min, curr_max = channelwise_minmax(x, self.channel_axis)
+
+         # Broadcasting handles scalar-vs-vector cases
+         self.min_val = torch.minimum(self.min_val, curr_min)
+         self.max_val = torch.maximum(self.max_val, curr_max)
tico/quantization/wrapq/observers/mx.py (new file)
@@ -0,0 +1,60 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+ from tico.quantization.wrapq.observers.base import ObserverBase
+ from tico.utils.mx.mx_ops import quantize_mx
+
+
+ class MXObserver(ObserverBase):
+     """MX (micro-scaling) observer: no min/max, no affine qparams."""
+
+     def __init__(
+         self,
+         *,
+         name: str,
+         elem_format: str = "int8",
+         axis: int = 0,
+         shared_exp_method: str = "max",
+         round: str = "nearest",
+         **base_kwargs,
+     ):
+         super().__init__(name=name, **base_kwargs)
+         self.elem_format = elem_format
+         self.axis = axis
+         self.shared_exp_method = shared_exp_method
+         self.round = round
+
+     def reset(self) -> None:
+         # No state to reset
+         return
+
+     @torch.no_grad()
+     def _update_stats(self, x: torch.Tensor) -> None:
+         # No stats required
+         return None
+
+     def compute_qparams(self):
+         # MX path does not produce affine qparams; keep interface contract.
+         return None
+
+     def fake_quant(self, x: torch.Tensor) -> torch.Tensor:
+         return quantize_mx(
+             x,
+             elem_format=self.elem_format,
+             axis=self.axis,
+             shared_exp_method=self.shared_exp_method,
+             round=self.round,
+         )
tico/quantization/wrapq/qscheme.py (new file)
@@ -0,0 +1,40 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from enum import auto, Enum
+
+
+ class QScheme(Enum):
+     # ───── Per-tensor ────────────
+     PER_TENSOR_ASYMM = auto()
+     PER_TENSOR_SYMM = auto()
+     # ───── Per-channel ───────────
+     PER_CHANNEL_ASYMM = auto()
+     PER_CHANNEL_SYMM = auto()
+
+     # helper
+     def is_per_channel(self) -> bool:
+         return self in {
+             QScheme.PER_CHANNEL_ASYMM,
+             QScheme.PER_CHANNEL_SYMM,
+         }
+
+     def is_symmetric(self) -> bool:
+         return self in {
+             QScheme.PER_TENSOR_SYMM,
+             QScheme.PER_CHANNEL_SYMM,
+         }
+
+     def __str__(self) -> str:
+         return self.name.lower()
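A brief usage sketch of the helpers above (assuming the new wrapq package layout shown in this diff is importable):

from tico.quantization.wrapq.qscheme import QScheme

qs = QScheme.PER_CHANNEL_SYMM
assert qs.is_per_channel() and qs.is_symmetric()
print(qs)  # per_channel_symm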
tico/quantization/wrapq/quantizer.py (new file)
@@ -0,0 +1,179 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Dict, Optional
+
+ import torch
+ import torch.nn as nn
+
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.quantizer import BaseQuantizer
+ from tico.quantization.quantizer_registry import register_quantizer
+
+ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+
+
+ @register_quantizer(PTQConfig)
+ class PTQQuantizer(BaseQuantizer):
+     """
+     Post-Training Quantization (PTQ) quantizer integrated with the public interface.
+
+     Features
+     --------
+     • Automatically wraps quantizable modules using PTQWrapper.
+     • Supports leaf-level (single-module) quantization (e.g., prepare(model.fc, PTQConfig())).
+     • Enforces strict wrapping if `strict_wrap=True`: raises NotImplementedError if
+       no quantizable module was found at any boundary.
+     • If `strict_wrap=False`, unquantizable modules are silently skipped.
+     """
+
+     def __init__(self, config: PTQConfig):
+         super().__init__(config)
+         self.qcfg: PTQConfig = config
+         self.strict_wrap: bool = bool(getattr(config, "strict_wrap", True))
+
+     @torch.no_grad()
+     def prepare(
+         self,
+         model: torch.nn.Module,
+         args: Optional[Any] = None,
+         kwargs: Optional[Dict[str, Any]] = None,
+     ):
+         # Wrap the tree (or single module) according to strictness policy
+         model = self._wrap_supported(model, self.qcfg)
+
+         # Switch all quant modules into calibration mode
+         if isinstance(model, QuantModuleBase):
+             model.enable_calibration()
+         for m in model.modules():
+             if isinstance(m, QuantModuleBase):
+                 m.enable_calibration()
+         return model
+
+     @torch.no_grad()
+     def convert(self, model):
+         # Freeze qparams across the tree (QUANT mode)
+         if isinstance(model, QuantModuleBase):
+             model.freeze_qparams()
+         for m in model.modules():
+             if isinstance(m, QuantModuleBase):
+                 m.freeze_qparams()
+         return model
+
+     def _wrap_supported(
+         self,
+         root: nn.Module,
+         qcfg: PTQConfig,
+     ) -> nn.Module:
+         """
+         Recursively attempt to wrap boundaries. Strictness is applied at every boundary.
+         """
+         assert not isinstance(root, QuantModuleBase), "The module is already wrapped."
+
+         # Case A: HuggingFace-style transformers: model.model.layers
+         lm = getattr(root, "model", None)
+         layers = getattr(lm, "layers", None) if isinstance(lm, nn.Module) else None
+         if isinstance(layers, nn.ModuleList):
+             new_list = nn.ModuleList()
+             for idx, layer in enumerate(layers):
+                 child_scope = f"layer{idx}"
+                 child_cfg = qcfg.child(child_scope)
+
+                 # Enforce strictness at the child boundary
+                 wrapped = self._try_wrap(
+                     layer,
+                     child_cfg,
+                     fp_name=child_scope,
+                     raise_on_fail=self.strict_wrap,
+                 )
+                 new_list.append(wrapped)
+             lm.layers = new_list  # type: ignore[union-attr]
+             return root
+
+         # Case B: Containers
+         if isinstance(root, (nn.Sequential, nn.ModuleList)):
+             for i, child in enumerate(list(root)):
+                 name = str(i)
+                 child_cfg = qcfg.child(name)
+
+                 wrapped = self._try_wrap(
+                     child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+                 )
+                 if wrapped is child:
+                     assert not self.strict_wrap
+                     wrapped = self._wrap_supported(wrapped, child_cfg)
+                 root[i] = wrapped  # type: ignore[index]
+
+         if isinstance(root, nn.ModuleDict):
+             for k, child in list(root.items()):
+                 name = k
+                 child_cfg = qcfg.child(name)
+
+                 wrapped = self._try_wrap(
+                     child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+                 )
+                 if wrapped is child:
+                     assert not self.strict_wrap
+                     wrapped = self._wrap_supported(wrapped, child_cfg)
+                 root[k] = wrapped  # type: ignore[index]
+
+         # Case C: Leaf node
+         root_name = getattr(root, "_get_name", lambda: None)()
+         wrapped = self._try_wrap(
+             root, qcfg, fp_name=root_name, raise_on_fail=self.strict_wrap
+         )
+         if wrapped is not root:
+             return wrapped
+
+         assert not self.strict_wrap
+         # Case D: Named children
+         for name, child in list(root.named_children()):
+             child_cfg = qcfg.child(name)
+
+             wrapped = self._try_wrap(
+                 child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+             )
+             if wrapped is child:
+                 assert not self.strict_wrap
+                 wrapped = self._wrap_supported(wrapped, child_cfg)
+             setattr(root, name, wrapped)
+
+         return root
+
+     def _try_wrap(
+         self,
+         module: nn.Module,
+         qcfg_for_child: PTQConfig,
+         *,
+         fp_name: Optional[str],
+         raise_on_fail: bool,
+     ) -> nn.Module:
+         """
+         Attempt to wrap a boundary with PTQWrapper.
+
+         Behavior:
+           • If PTQWrapper succeeds: return wrapped module.
+           • If PTQWrapper raises NotImplementedError:
+             - raise_on_fail=True -> re-raise (strict)
+             - raise_on_fail=False -> return original module (permissive)
+         """
+         try:
+             return PTQWrapper(module, qcfg=qcfg_for_child, fp_name=fp_name)
+         except NotImplementedError as e:
+             if raise_on_fail:
+                 raise NotImplementedError(
+                     f"PTQQuantizer: no quantization wrapper for {type(module).__name__}"
+                 ) from e
+             return module
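A hedged end-to-end sketch of the prepare → calibrate → convert lifecycle implemented above, using the leaf-level case mentioned in the class docstring. The no-argument PTQConfig() call comes from that docstring; the nn.Linear target and the random calibration data are illustrative assumptions only.

import torch
import torch.nn as nn

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.quantizer import PTQQuantizer

fc = nn.Linear(16, 8)
quantizer = PTQQuantizer(PTQConfig())
fc_q = quantizer.prepare(fc)            # wraps the module and enables calibration
with torch.no_grad():
    for _ in range(8):
        fc_q(torch.randn(4, 16))        # observers collect activation statistics
fc_q = quantizer.convert(fc_q)          # freezes qparams for fake-quant inference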
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE