tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
tico/quantization/algorithm/gptq/quantizer.py
@@ -0,0 +1,292 @@
+# Copyright (c) 2024 Intel Corporation
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import types
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+from tqdm.auto import tqdm
+
+from tico.quantization.algorithm.gptq.gptq import GPTQ
+from tico.quantization.algorithm.gptq.utils import (
+    find_layers,
+    gather_single_batch_from_dict,
+    gather_single_batch_from_list,
+)
+from tico.quantization.config.gptq import GPTQConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import register_quantizer
+
+
+class StopForward(Exception):
+    """Custom exception used to stop the forward pass after the first layer."""
+
+    pass
+
+
+@register_quantizer(GPTQConfig)
+class GPTQQuantizer(BaseQuantizer):
+    """
+    Quantizer for applying the GPTQ algorithm (typically for weight quantization).
+    This implementation expects:
+    1) prepare(model, ...) to only attach hooks/Catchers and NOT run the model internally.
+    2) The user runs the model with an arbitrary number of batches to collect calibration data.
+    3) convert(model) to consume the collected data and apply GPTQ.
+    """
+
+    def __init__(self, config: GPTQConfig):
+        super().__init__(config)
+
+        # cache_args[i] -> list of the i-th positional argument for each batch
+        self.cache_args: List[List[Any]] = []
+        # cache_kwargs[k] -> list of the value for keyword k for each batch
+        self.cache_kwargs: Dict[str, List[Any]] = {}
+        self.num_batches: int = 0
+
+        # References to original forwards for restoration
+        self._orig_model_forward: Optional[Callable[..., Any]] = None
+        self._orig_layer_forward: Optional[Callable[..., Any]] = None
+        self._first_layer_ref: Optional[torch.nn.Module] = None
+
+    @torch.no_grad()
+    def prepare(
+        self,
+        model: torch.nn.Module,
+        args: Optional[Any] = None,
+        kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Overrides the forward method of the first LLaMA layer (layer 0) to capture the
+        input required for calibration.
+
+        When the user calls `model(...)`, we intercept (and store) the inputs to that
+        layer, then raise an exception to stop the forward pass immediately. These
+        captured inputs are then utilized to calibrate the quantization parameters
+        for GPTQ.
+
+        Parameters:
+            model (torch.nn.Module): The target PyTorch model
+            args (Any, optional): Unused (kept for API compatibility)
+            kwargs (Dict[str, Any], optional): Unused (kept for API compatibility)
+
+        Returns:
+            torch.nn.Module: The model with the catcher attached
+        """
+        # Define the catcher to store inputs/kwargs and stop the execution
+        def forward(layer, *args, **kwargs):
+            """
+            Stores this batch's inputs and kwargs, then raises StopForward to stop computation.
+            """
+            # Store positional args
+            for idx, item in enumerate(args):
+                if (idx + 1) > len(self.cache_args):
+                    self.cache_args.append([])
+                self.cache_args[idx].append(item)
+            # Store keyword args
+            for k, v in kwargs.items():
+                if self.cache_kwargs.get(k, None) is None:
+                    self.cache_kwargs[k] = []
+                self.cache_kwargs[k].append(v)
+
+            self.num_batches += 1
+            raise StopForward  # stop after the first layer
+
+        # Replace the first layer's forward with the catcher to capture calibration data.
+        if hasattr(model, "model"):
+            if hasattr(model.model, "layers") and isinstance(
+                model.model.layers, torch.nn.ModuleList
+            ):
+                self._first_layer_ref = model.model.layers[0]
+            else:
+                raise RuntimeError(
+                    "GPTQ Quantizer assumes the model has a nested structure like `model.model.layers`, commonly found in LLaMA and other Hugging Face transformer models."
+                )
+        else:
+            # Fallback if the model is not LLaMA-like; treat the whole model as a single layer
+            self._first_layer_ref = model
+
+        assert hasattr(self._first_layer_ref, "forward")
+        # Back up the original forward of the first layer
+        assert isinstance(self._first_layer_ref, torch.nn.Module)
+        self._orig_layer_forward = self._first_layer_ref.forward
+        self._first_layer_ref.forward = types.MethodType(forward, self._first_layer_ref)
+
+        def model_forward_wrapper(_model, *m_args, **m_kwargs):
+            """
+            Wrapper to ignore StopForward exceptions so the user's calibration loop doesn't crash.
+            """
+            try:
+                assert self._orig_model_forward is not None
+                return self._orig_model_forward(*m_args, **m_kwargs)
+            except StopForward:
+                # We stopped after the first layer; return None or a dummy output if needed.
+                return None
+
+        # Back up model.forward so we can suppress StopForward
+        self._orig_model_forward = model.forward
+        model.forward = types.MethodType(model_forward_wrapper, model)
+
+        return model
+
+    @torch.no_grad()
+    def convert(self, model):
+        """
+        Perform GPTQ quantization using cached first-layer inputs.
+
+        Steps:
+        1) Restore original forwards (no more catching).
+        2) Iterate through each Transformer layer sequentially:
+           a) For each layer, register forward hooks to collect (inp, out) stats for GPTQ.
+           b) Run the layer on cached inputs for all batches.
+           c) Apply GPTQ and update the weights.
+           d) Re-run the layer to produce outputs for the next layer; update cached inputs.
+        3) Restore model.config.use_cache if needed and clear internal caches.
+
+        Parameters:
+            model (torch.nn.Module): The prepared model.
+
+        Returns:
+            torch.nn.Module: Quantized model.
+        """
+        # Restore original forwards (we no longer want to stop after the first layer)
+        assert self._orig_model_forward is not None
+        model.forward = self._orig_model_forward
+        assert (
+            self._first_layer_ref is not None and self._orig_layer_forward is not None
+        )
+        self._first_layer_ref.forward = self._orig_layer_forward
+
+        gptq_conf = self.config
+        assert isinstance(gptq_conf, GPTQConfig)
+        # Disable use_cache during calibration
+        if hasattr(model, "config") and hasattr(model.config, "use_cache"):
+            orig_use_cache = model.config.use_cache
+            model.config.use_cache = False
+        else:
+            orig_use_cache = None
+
+        # Identify layers
+        if hasattr(model, "model"):
+            target_layers = model.model.layers
+        else:
+            target_layers = [model]
+
+        quantizers: Dict[str, Any] = {}
+        for l_idx, layer in enumerate(
+            tqdm(
+                target_layers,
+                desc="Quantizing layers",
+                unit="layer",
+                disable=not gptq_conf.show_progress,
+            )
+        ):
+            # 1) Identify quantizable submodules within the layer
+            full = find_layers(layer)
+            sequential = [list(full.keys())]
+
+            # 2) Set up GPTQ objects and gather stats
+            for names in sequential:
+                subset = {n: full[n] for n in names}
+
+                gptq: Dict[str, GPTQ] = {}
+                for name in subset:
+                    gptq[name] = GPTQ(subset[name])
+                    gptq[name].quantizer.configure(
+                        bits=8, perchannel=True, sym=False, mse=False
+                    )
+
+                # Hook to collect (inp, out) for GPTQ
+                def add_batch(name):
+                    def _hook(_, inp, out):
+                        gptq[name].add_batch(inp[0].data, out.data)
+
+                    return _hook
+
+                handles = []
+                for name in subset:
+                    handles.append(subset[name].register_forward_hook(add_batch(name)))
+
+                # Run layer forward over all cached batches to build Hessian/statistics
+                batch_num = self.num_batches
+                for batch_idx in tqdm(
+                    range(batch_num),
+                    desc=f"[L{l_idx}] collecting",
+                    leave=False,
+                    unit="batch",
+                    disable=not gptq_conf.show_progress,
+                ):
+                    cache_args_batch = gather_single_batch_from_list(
+                        self.cache_args, batch_idx
+                    )
+                    cache_kwargs_batch = gather_single_batch_from_dict(
+                        self.cache_kwargs, batch_idx
+                    )
+                    layer(*cache_args_batch, **cache_kwargs_batch)
+
+                # Remove handles
+                for h in handles:
+                    h.remove()
+
+                # 3) Quantize each submodule
+                for name in subset:
+                    if gptq_conf.verbose:
+                        print(f"[Layer {l_idx}] {name} -> Quantizing ...")
+                    gptq[name].fasterquant(
+                        percdamp=0.01,
+                        groupsize=-1,
+                        actorder=True,
+                        static_groups=False,
+                        verbose=gptq_conf.verbose,
+                    )
+                    quantizers[f"model.layers.{l_idx}.{name}"] = gptq[name].quantizer
+                    gptq[name].free()
+
+            # 4) After quantization, re-run the layer to produce outputs for the next layer
+            for batch_idx in tqdm(
+                range(batch_num),
+                desc=f"[L{l_idx}] re-forward",
+                leave=False,
+                unit="batch",
+                disable=not gptq_conf.show_progress,
+            ):
+                cache_args_batch = gather_single_batch_from_list(
+                    self.cache_args, batch_idx
+                )
+                cache_kwargs_batch = gather_single_batch_from_dict(
+                    self.cache_kwargs, batch_idx
+                )
+                outs = layer(*cache_args_batch, **cache_kwargs_batch)
+                # LLaMA's decoder layer return type differs across Transformers versions:
+                # some return a tuple (hidden_states, ...), others return just a tensor.
+                # This line ensures we always take the first element when it's a tuple.
+                outs = outs[0] if isinstance(outs, tuple) else outs
+                # Update inputs for the next iteration.
+                self.cache_args[0][batch_idx] = outs
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        # Restore the original cache configuration.
+        if orig_use_cache is not None:
+            model.config.use_cache = orig_use_cache
+
+        # Clear caches to free memory
+        self.cache_args.clear()
+        self.cache_kwargs.clear()
+        self.num_batches = 0
+
+        model.quantizers = quantizers
+
+        return model
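
For orientation, the following is a minimal sketch of how the prepare → calibrate → convert flow described in the class docstring above might be driven from user code. The config arguments, `model`, and the calibration loop are illustrative assumptions, not part of this diff:

    import torch
    from tico.quantization.config.gptq import GPTQConfig
    from tico.quantization.algorithm.gptq.quantizer import GPTQQuantizer

    # `model` is a LLaMA-like Hugging Face model and `calib_batches` an iterable of
    # input tensors; both are user-provided and assumed here only for illustration.
    quantizer = GPTQQuantizer(GPTQConfig())  # assumes GPTQConfig is constructible with defaults
    model = quantizer.prepare(model)         # attaches the first-layer catcher; does not run the model

    with torch.no_grad():
        for batch in calib_batches:
            model(batch)                     # each call stops after layer 0 via StopForward

    model = quantizer.convert(model)         # consumes the cached inputs and applies GPTQ per layer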
tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py
@@ -58,7 +58,7 @@ def gather_single_batch_from_list(data_list, idx):
     Returns:
         list: single batch.
     """
-    # obtain a set of keyword input from cache
+    # obtain a set of positional input from cache
     single_batch = []
     for data_item in data_list:
         single_batch.append(data_item[idx])
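
The cache layout this helper reads (see GPTQQuantizer.__init__ above) keeps one inner list per positional argument, each holding that argument's value for every calibration batch, so gathering one batch is just indexing each inner list. A tiny self-contained illustration, with placeholder strings standing in for tensors:

    # Equivalent re-statement of the helper shown above, for illustration only.
    def gather_single_batch_from_list(data_list, idx):
        return [data_item[idx] for data_item in data_list]

    cache_args = [
        ["arg0_batch0", "arg0_batch1"],  # first positional argument, per batch
        ["arg1_batch0", "arg1_batch1"],  # second positional argument, per batch
    ]
    print(gather_single_batch_from_list(cache_args, 1))  # ['arg0_batch1', 'arg1_batch1']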
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py
@@ -21,23 +21,16 @@ if TYPE_CHECKING:
     import torch.fx
     from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
 import torch
-from torch.ao.quantization.observer import (
-    MinMaxObserver,
-    MovingAverageMinMaxObserver,
-    MovingAveragePerChannelMinMaxObserver,
-    PerChannelMinMaxObserver,
-)
+from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
 from torch.ao.quantization.quantizer.utils import _get_module_name_filter
 
-from tico.experimental.quantization.algorithm.pt2e.annotation.op import *
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
-from tico.experimental.quantization.algorithm.pt2e.transformation.convert_scalars_to_attrs import (
+from tico.quantization.algorithm.pt2e.annotation.op import *
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
+from tico.quantization.algorithm.pt2e.transformation.convert_scalars_to_attrs import (
     convert_scalars_to_attrs,
 )
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py
@@ -19,12 +19,10 @@ if TYPE_CHECKING:
 import torch
 from torch.ao.quantization.quantizer import SharedQuantizationSpec
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import AdaptiveAvgPool2dArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import AddTensorArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py
@@ -19,12 +19,10 @@ if TYPE_CHECKING:
 import torch
 from torch.ao.quantization.quantizer import DerivedQuantizationSpec
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import Conv2DArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import DivTensorArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py
@@ -12,19 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, List, Optional, TYPE_CHECKING
+from typing import Callable, Optional, TYPE_CHECKING
 
 if TYPE_CHECKING:
     import torch.fx
 import torch
 from torch.ao.quantization.quantizer import DerivedQuantizationSpec
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import LinearArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import MeanDimArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import MulTensorArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import Relu6Args
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import RsqrtArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import SubTensorArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py
@@ -18,9 +18,7 @@ if TYPE_CHECKING:
     import torch.fx
 import torch
 
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 
 AnnotatorType = Callable[
     [
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py
@@ -22,7 +22,7 @@ from torch.ao.quantization.quantizer import (
     SharedQuantizationSpec,
 )
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
 
 
 def annotate_input_qspec_map(node: torch.fx.Node, input_node: torch.fx.Node, qspec):
tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py
@@ -18,13 +18,16 @@ import torch
 
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
-from tico.experimental.quantization.algorithm.pt2e.annotation.annotator import (
+from tico.quantization.algorithm.pt2e.annotation.annotator import (
     get_asymmetric_quantization_config,
     PT2EAnnotator,
 )
-from tico.experimental.quantization.quantizer import BaseQuantizer
+from tico.quantization.config.pt2e import PT2EConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import register_quantizer
 
 
+@register_quantizer(PT2EConfig)
 class PT2EQuantizer(BaseQuantizer):
     """
     Quantizer for applying pytorch 2.0 export quantization (typically for activation quantization).
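
This mirrors the registration added to GPTQQuantizer above: each quantizer class is bound to its config class through @register_quantizer, so callers can be dispatched from the config object they pass in. A rough sketch of the kind of registry such a decorator implies; tico/quantization/quantizer_registry.py is added in this diff but not shown here, so these internals and the get_quantizer helper are assumptions, not the actual contents:

    # Illustrative registry sketch only, not the real quantizer_registry.py.
    from typing import Dict

    _REGISTRY: Dict[type, type] = {}

    def register_quantizer(config_cls: type):
        def wrap(quantizer_cls: type):
            _REGISTRY[config_cls] = quantizer_cls  # map config class -> quantizer class
            return quantizer_cls
        return wrap

    def get_quantizer(config):
        # Instantiate the quantizer registered for this config's type.
        return _REGISTRY[type(config)](config)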
tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py
@@ -19,11 +19,8 @@ if TYPE_CHECKING:
 import torch
 from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.utils import _get_module_name_filter
-from torch.utils import _pytree as pytree
 
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 
 
 def get_module_type_filter(tp: Callable):
tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import functools
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Literal
 
 import torch
 
@@ -21,18 +21,24 @@ import torch
 class ChannelwiseMaxActsObserver:
     """
     Observer to calculate channelwise maximum activation
+    It supports collecting activations from either module inputs or outputs.
     """
 
-    def __init__(self, model):
+    def __init__(
+        self, model: torch.nn.Module, acts_from: Literal["input", "output"] = "input"
+    ):
         """
         model
             A torch module whose activations are to be analyzed.
+        acts_from
+            Where to hook: "input" for forward-pre-hook, "output" for forward-hook.
         hooks
-            A list to store the hooks which are registered to collect activation statistics.
+            A list to store the hooks registered to collect activation statistics.
         max_acts
-            A dictionary to store the maximum activation values
+            A dictionary to store the per-channel maxima.
         """
         self.model = model
+        self.acts_from: Literal["input", "output"] = acts_from
         self.hooks: List[Any] = []
         self.max_acts: Dict[str, torch.Tensor] = {}
 
@@ -62,13 +68,25 @@ class ChannelwiseMaxActsObserver:
                 input = input[0]
             stat_tensor(name, input)
 
+        def stat_output_hook(m, input, output, name):
+            if isinstance(output, tuple):
+                output = output[0]
+            stat_tensor(name, output)
+
         for name, m in self.model.named_modules():
             if isinstance(m, torch.nn.Linear):
-                self.hooks.append(
-                    m.register_forward_pre_hook(
-                        functools.partial(stat_input_hook, name=name)
+                if self.acts_from == "input":
+                    self.hooks.append(
+                        m.register_forward_pre_hook(
+                            functools.partial(stat_input_hook, name=name)
+                        )
+                    )
+                else:  # "output"
+                    self.hooks.append(
+                        m.register_forward_hook(
+                            functools.partial(stat_output_hook, name=name)
+                        )
                     )
-                )
 
     def remove(self):
         for hook in self.hooks:
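
The new acts_from switch simply chooses between PyTorch's two standard hook points: a forward-pre-hook sees a module's inputs, a forward-hook sees its output. A standalone illustration of that difference, independent of tico:

    import torch

    lin = torch.nn.Linear(4, 2)

    # forward-pre-hook receives (module, inputs): what acts_from="input" collects
    lin.register_forward_pre_hook(lambda m, inp: print("pre :", inp[0].shape))
    # forward-hook receives (module, inputs, output): what acts_from="output" collects
    lin.register_forward_hook(lambda m, inp, out: print("post:", out.shape))

    lin(torch.randn(3, 4))  # prints "pre : torch.Size([3, 4])" then "post: torch.Size([3, 2])"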