tico 0.1.0.dev250924__py3-none-any.whl → 0.1.0.dev251109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (114)
  1. tico/__init__.py +1 -1
  2. tico/quantization/__init__.py +6 -0
  3. tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +161 -0
  4. tico/quantization/algorithm/fpi_gptq/quantizer.py +179 -0
  5. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
  6. tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +12 -6
  7. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
  8. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  9. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  10. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  11. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  12. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
  13. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  14. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  15. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  16. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  17. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  18. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  19. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  20. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +4 -4
  21. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
  22. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +6 -10
  23. tico/quantization/config/fpi_gptq.py +29 -0
  24. tico/{experimental/quantization → quantization}/config/gptq.py +1 -1
  25. tico/{experimental/quantization → quantization}/config/pt2e.py +1 -1
  26. tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py} +18 -10
  27. tico/{experimental/quantization → quantization}/config/smoothquant.py +1 -1
  28. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +6 -12
  29. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +1 -3
  30. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  31. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  32. tico/{experimental/quantization → quantization}/public_interface.py +7 -7
  33. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  34. tico/{experimental/quantization → quantization}/quantizer_registry.py +11 -10
  35. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/compare_ppl.py +8 -19
  36. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/debug_quant_outputs.py +9 -24
  37. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py +11 -10
  38. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py +10 -12
  39. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py +10 -9
  40. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py +13 -13
  41. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_with_gptq.py +14 -35
  42. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py +3 -3
  43. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py +2 -2
  44. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py +2 -2
  45. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py +1 -1
  46. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py +2 -2
  47. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py +1 -1
  48. tico/quantization/wrapq/quantizer.py +179 -0
  49. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py +3 -5
  50. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py +3 -2
  51. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/__init__.py +1 -1
  52. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_decoder.py +6 -8
  53. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_decoder_layer.py +6 -8
  54. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_encoder.py +6 -8
  55. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_encoder_layer.py +6 -8
  56. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_mha.py +5 -7
  57. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py +5 -7
  58. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py +8 -12
  59. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py +5 -7
  60. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  61. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py +6 -7
  62. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py +7 -8
  63. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py +8 -9
  64. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py +4 -6
  65. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py +55 -17
  66. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py +10 -9
  67. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py +17 -16
  68. tico/utils/convert.py +9 -14
  69. {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251109.dist-info}/METADATA +48 -2
  70. {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251109.dist-info}/RECORD +113 -108
  71. tico/experimental/quantization/__init__.py +0 -6
  72. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  73. /tico/{experimental/quantization/algorithm/gptq → quantization/algorithm/fpi_gptq}/__init__.py +0 -0
  74. /tico/{experimental/quantization/algorithm/pt2e → quantization/algorithm/gptq}/__init__.py +0 -0
  75. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  76. /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
  77. /tico/{experimental/quantization/algorithm/pt2e/annotation → quantization/algorithm/pt2e}/__init__.py +0 -0
  78. /tico/{experimental/quantization/algorithm/pt2e/transformation → quantization/algorithm/pt2e/annotation}/__init__.py +0 -0
  79. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  80. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  81. /tico/{experimental/quantization/algorithm/smoothquant → quantization/algorithm/pt2e/transformation}/__init__.py +0 -0
  82. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  83. /tico/{experimental/quantization/config → quantization/algorithm/smoothquant}/__init__.py +0 -0
  84. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +0 -0
  85. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/smooth_quant.py +0 -0
  86. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  87. /tico/{experimental/quantization → quantization}/config/base.py +0 -0
  88. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  89. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  90. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  91. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  92. /tico/{experimental/quantization → quantization}/evaluation/metric.py +0 -0
  93. /tico/{experimental/quantization/ptq → quantization/passes}/__init__.py +0 -0
  94. /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
  95. /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
  96. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  97. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  98. /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
  99. /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
  100. /tico/{experimental/quantization/ptq/examples → quantization/wrapq}/__init__.py +0 -0
  101. /tico/{experimental/quantization/ptq → quantization/wrapq}/dtypes.py +0 -0
  102. /tico/{experimental/quantization/ptq/observers → quantization/wrapq/examples}/__init__.py +0 -0
  103. /tico/{experimental/quantization/ptq → quantization/wrapq}/mode.py +0 -0
  104. /tico/{experimental/quantization/ptq/utils → quantization/wrapq/observers}/__init__.py +0 -0
  105. /tico/{experimental/quantization/ptq → quantization/wrapq}/qscheme.py +0 -0
  106. /tico/{experimental/quantization/ptq/wrappers → quantization/wrapq/utils}/__init__.py +0 -0
  107. /tico/{experimental/quantization/ptq → quantization/wrapq}/utils/reduce_utils.py +0 -0
  108. /tico/{experimental/quantization/ptq/wrappers/llama → quantization/wrapq/wrappers}/__init__.py +0 -0
  109. /tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/decoder_export_single_step.py +0 -0
  110. /tico/{experimental/quantization/ptq/wrappers/nn → quantization/wrapq/wrappers/llama}/__init__.py +0 -0
  111. {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251109.dist-info}/LICENSE +0 -0
  112. {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251109.dist-info}/WHEEL +0 -0
  113. {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251109.dist-info}/entry_points.txt +0 -0
  114. {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251109.dist-info}/top_level.txt +0 -0
tico/__init__.py CHANGED
@@ -29,7 +29,7 @@ __all__ = [
 ]
 
 # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
-__version__ = "0.1.0.dev250924"
+__version__ = "0.1.0.dev251109"
 
 MINIMUM_SUPPORTED_VERSION = "2.5.0"
 SECURE_TORCH_VERSION = "2.6.0"
tico/quantization/__init__.py ADDED
@@ -0,0 +1,6 @@
+from tico.quantization.public_interface import convert, prepare
+
+__all__ = [
+    "convert",
+    "prepare",
+]
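The practical effect of this new package is that the public quantization entry points now live under tico.quantization instead of tico.experimental.quantization (whose own __init__.py loses the same six lines, item 71 in the file list). A minimal before/after import sketch; only the import paths are confirmed by this diff, the prepare/convert signatures are not shown here:

# Old import path, removed in this release:
#   from tico.experimental.quantization import convert, prepare
# New import path:
from tico.quantization import convert, prepare
from tico.quantization.config.gptq import GPTQConfig  # config classes moved the same way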
tico/quantization/algorithm/fpi_gptq/fpi_gptq.py ADDED
@@ -0,0 +1,161 @@
+# Copyright IST-DASLab. 2025. (commit: 2d65066). GitHub repository.
+# Retrieved from https://github.com/IST-DASLab/gptq. Licensed under the
+# Apache License 2.0.
+
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# https://github.com/IST-DASLab/gptq/blob/2d65066/gptq.py
+
+import math
+import time
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.algorithm.gptq.quant import quantize, Quantizer
+
+
+def iterate_GPTQ(scale, zero, maxq, W, Hinv, max_num_of_iters=50):
+
+    cur_weights = W.clone()
+    mults = torch.pow(torch.diag(Hinv), -1)
+    Hinv_U = torch.triu(Hinv, diagonal=1)
+
+    init_weights = W.clone()
+    for _ in range(max_num_of_iters):
+        cur_Q = quantize(cur_weights, scale, zero, maxq)
+
+        d_W = torch.mul((cur_weights - cur_Q), mults)
+        cur_weights = init_weights - torch.matmul(d_W, Hinv_U)
+        del d_W, cur_Q
+        d_W = cur_Q = None
+
+    del init_weights
+    init_weights = None
+
+    cur_Q = quantize(cur_weights, scale, zero, maxq)
+
+    return cur_Q, cur_weights
+
+
+class FPI_GPTQ:
+    def __init__(self, layer):
+        self.layer = layer
+        self.dev = self.layer.weight.device
+        W = layer.weight.data.clone()
+        if isinstance(self.layer, nn.Conv2d):
+            W = W.flatten(1)
+
+        if isinstance(self.layer, nn.Conv1d):
+            W = W.t()
+        self.rows = W.shape[0]
+        self.columns = W.shape[1]
+        self.H: Optional[torch.Tensor] = torch.zeros(
+            (self.columns, self.columns), device=self.dev
+        )
+        self.nsamples = 0
+        self.quantizer: Quantizer = Quantizer()
+
+    def add_batch(self, inp, out):
+        if len(inp.shape) == 2:
+            inp = inp.unsqueeze(0)
+        tmp = inp.shape[0]
+        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, nn.Conv1d):
+            if len(inp.shape) > 2:
+                inp = inp.reshape((-1, inp.shape[-1]))
+            inp = inp.t()
+        if isinstance(self.layer, nn.Conv2d):
+            unfold = nn.Unfold(
+                self.layer.kernel_size,
+                dilation=self.layer.dilation,
+                padding=self.layer.padding,
+                stride=self.layer.stride,
+            )
+
+            inp = unfold(inp)
+            inp = inp.permute([1, 0, 2])
+            inp = inp.flatten(1)
+
+        self.H *= self.nsamples / (self.nsamples + tmp)
+        self.nsamples += tmp
+        inp = math.sqrt(2 / self.nsamples) * inp.float()
+        self.H += inp.matmul(inp.t())
+
+    def fasterquant(
+        self,
+        percdamp=0.01,
+        verbose=False,
+    ):
+        W = self.layer.weight.data.clone()
+        if isinstance(self.layer, nn.Conv2d):
+            W = W.flatten(1)
+        if isinstance(self.layer, nn.Conv1d):
+            W = W.t()
+        W = W.float()
+        tick = time.time()
+        if not self.quantizer.ready():
+            self.quantizer.find_params(W, weight=True)
+
+        H = self.H
+        del self.H
+        assert isinstance(H, torch.Tensor)
+        dead = torch.diag(H) == 0
+        H[dead, dead] = 1
+        W[:, dead] = 0
+
+        # actorder
+        perm = torch.argsort(torch.diag(H), descending=True)
+        W = W[:, perm]
+        H = H[perm][:, perm]
+        invperm = torch.argsort(perm)
+
+        Q = torch.zeros_like(W)
+
+        damp = percdamp * torch.mean(torch.diag(H))
+        diag = torch.arange(self.columns, device=self.dev)
+        H[diag, diag] += damp
+        H = torch.linalg.cholesky(H)
+        assert isinstance(H, torch.Tensor)
+        H = torch.cholesky_inverse(H)
+        H = torch.linalg.cholesky(H, upper=True)
+        Hinv = H
+
+        Q, W = iterate_GPTQ(
+            self.quantizer.scale,
+            self.quantizer.zero,
+            self.quantizer.maxq,
+            W,
+            Hinv=Hinv,
+            max_num_of_iters=50,
+        )
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        if verbose:
+            print("time %.2f" % (time.time() - tick))
+            Losses = 0.5 * ((Q - W) / torch.diag(Hinv)) ** 2
+            print("error", torch.sum(Losses).item())
+
+        Q = Q[:, invperm]
+
+        self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(
+            self.layer.weight.data.dtype
+        )
+
+    def free(self):
+        self.H = None
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
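For intuition, iterate_GPTQ above replaces reference GPTQ's sequential, column-by-column error propagation with a batched fixed-point update: starting from the original weights W0, it repeatedly computes W <- W0 - ((W - Q(W)) / diag(Hinv)) @ triu(Hinv, 1), so each iteration is a handful of large matrix operations that map well onto a GPU. Below is a self-contained toy sketch of that update; fake_quantize is a stand-in mirroring the rounding of the quantize helper imported above, and all shapes and quantization parameters are illustrative only.

import torch

def fake_quantize(w, scale, zero, maxq):
    # stand-in for tico.quantization.algorithm.gptq.quant.quantize (uniform affine rounding)
    q = torch.clamp(torch.round(w / scale) + zero, 0, maxq)
    return scale * (q - zero)

torch.manual_seed(0)
rows, cols = 4, 8
W0 = torch.randn(rows, cols)                     # weights to quantize
X = torch.randn(64, cols)                        # toy calibration activations
H = X.t() @ X + 1e-2 * torch.eye(cols)           # damped Hessian proxy
# same Hinv construction as fasterquant(): upper Cholesky factor of H^-1
Hinv = torch.linalg.cholesky(torch.cholesky_inverse(torch.linalg.cholesky(H)), upper=True)

scale = (W0.abs().amax(dim=1, keepdim=True) / 127.5).clamp(min=1e-8)  # illustrative 8-bit params
zero = torch.full_like(scale, 128.0)
maxq = 255

mults = torch.pow(torch.diag(Hinv), -1)          # 1 / diag(Hinv)
upper = torch.triu(Hinv, diagonal=1)

W = W0.clone()
for _ in range(50):                              # the fixed-point iteration
    Q = fake_quantize(W, scale, zero, maxq)
    W = W0 - ((W - Q) * mults) @ upper           # same update as iterate_GPTQ
Q = fake_quantize(W, scale, zero, maxq)

# proxy for the GPTQ objective tr((Q - W0) H (Q - W0)^T)
err = ((Q - W0) @ torch.linalg.cholesky(H)).pow(2).sum()
print("quantization error:", err.item())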
tico/quantization/algorithm/fpi_gptq/quantizer.py ADDED
@@ -0,0 +1,179 @@
+# Copyright (c) 2024 Intel Corporation
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict
+
+import torch
+from tqdm.auto import tqdm
+
+from tico.quantization.algorithm.fpi_gptq.fpi_gptq import FPI_GPTQ
+from tico.quantization.algorithm.gptq.quantizer import GPTQQuantizer
+from tico.quantization.algorithm.gptq.utils import (
+    find_layers,
+    gather_single_batch_from_dict,
+    gather_single_batch_from_list,
+)
+from tico.quantization.config.fpi_gptq import FPIGPTQConfig
+from tico.quantization.quantizer_registry import register_quantizer
+
+
+@register_quantizer(FPIGPTQConfig)
+class FPIGPTQQuantizer(GPTQQuantizer):
+    """
+    Quantizer for applying the Fixed Point Iteration GPTQ algorithm (FPIGPTQ)
+    This implementation expects the same steps as GPTQQuantizer.
+    It should produce results very close to reference GPTQ but much faster when running on cuda.
+    """
+
+    def __init__(self, config: FPIGPTQConfig):
+        super().__init__(config)
+
+    @torch.no_grad()
+    def convert(self, model):
+
+        # Restore original forwards (we no longer want to stop after first layer)
+        assert self._orig_model_forward is not None
+        model.forward = self._orig_model_forward
+        assert (
+            self._first_layer_ref is not None and self._orig_layer_forward is not None
+        )
+        self._first_layer_ref.forward = self._orig_layer_forward
+
+        gptq_conf = self.config
+        assert isinstance(gptq_conf, FPIGPTQConfig)
+        # Disable use_cache during calibration
+        if hasattr(model, "config") and hasattr(model.config, "use_cache"):
+            orig_use_cache = model.config.use_cache
+            model.config.use_cache = False
+        else:
+            orig_use_cache = None
+
+        # Identify layers
+        if hasattr(model, "model"):
+            target_layers = model.model.layers
+        else:
+            target_layers = [model]
+
+        quantizers: Dict[str, Any] = {}
+        for l_idx, layer in enumerate(
+            tqdm(
+                target_layers,
+                desc="Quantizing layers",
+                unit="layer",
+                disable=not gptq_conf.show_progress,
+            )
+        ):
+            # 1) Identify quantizable submodules within the layer
+            full = find_layers(layer, layers=[torch.nn.Linear, torch.nn.Conv2d])
+            # filter out depthwise convolutions and alike
+            full = {
+                key: full[key]
+                for key in full.keys()
+                if not isinstance(full[key], torch.nn.Conv2d) or full[key].groups == 1
+            }
+
+            sequential = [list(full.keys())]
+
+            # 2) Set up (as in GPTQ)
+            for names in sequential:
+                subset = {n: full[n] for n in names}
+
+                gptq: Dict[str, FPI_GPTQ] = {}
+                for name in subset:
+                    gptq[name] = FPI_GPTQ(subset[name])
+                    gptq[name].quantizer.configure(
+                        bits=8, perchannel=True, sym=False, mse=False
+                    )
+
+                # Hook to collect (inp, out) for GPTQ
+                def add_batch(name):
+                    def _hook(_, inp, out):
+                        gptq[name].add_batch(inp[0].data, out.data)
+
+                    return _hook
+
+                handles = []
+                for name in subset:
+                    handles.append(subset[name].register_forward_hook(add_batch(name)))
+
+                # Run layer forward over all cached batches to build Hessian/statistics
+                batch_num = self.num_batches
+                for batch_idx in tqdm(
+                    range(batch_num),
+                    desc=f"[L{l_idx}] collecting",
+                    leave=False,
+                    unit="batch",
+                    disable=not gptq_conf.show_progress,
+                ):
+                    cache_args_batch = gather_single_batch_from_list(
+                        self.cache_args, batch_idx
+                    )
+                    cache_kwargs_batch = gather_single_batch_from_dict(
+                        self.cache_kwargs, batch_idx
+                    )
+                    layer(*cache_args_batch, **cache_kwargs_batch)
+
+                # Remove handles
+                for h in handles:
+                    h.remove()
+
+                # 3) Quantize each submodule
+                for name in subset:
+                    if gptq_conf.verbose:
+                        print(f"[Layer {l_idx}] {name} -> Quantizing ...")
+                    gptq[name].fasterquant(
+                        percdamp=0.01,
+                        verbose=gptq_conf.verbose,
+                    )
+                    quantizers[f"model.layers.{l_idx}.{name}"] = gptq[name].quantizer
+                    gptq[name].free()
+
+            # 4) After quantization, re-run the layer to produce outputs for the next layer
+            for batch_idx in tqdm(
+                range(batch_num),
+                desc=f"[L{l_idx}] re-forward",
+                leave=False,
+                unit="batch",
+                disable=not gptq_conf.show_progress,
+            ):
+                cache_args_batch = gather_single_batch_from_list(
+                    self.cache_args, batch_idx
+                )
+                cache_kwargs_batch = gather_single_batch_from_dict(
+                    self.cache_kwargs, batch_idx
+                )
+                outs = layer(*cache_args_batch, **cache_kwargs_batch)
+                # LLaMA's decoder layer return type differs across Transformers versions:
+                # some return a tuple (hidden_states, ...), others return just a tensor.
+                # This line ensures we always take the first element when it's a tuple.
+                outs = outs[0] if isinstance(outs, tuple) else outs
+                # Update inputs for next iteration.
+                self.cache_args[0][batch_idx] = outs
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        # Restore the original cache configuration.
+        if orig_use_cache is not None:
+            model.config.use_cache = orig_use_cache
+
+        # Clear caches to free memory
+        self.cache_args.clear()
+        self.cache_kwargs.clear()
+        self.num_batches = 0
+
+        model.quantizers = quantizers
+
+        return model
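Since FPIGPTQQuantizer reuses GPTQQuantizer's preparation and calibration machinery and only overrides convert, the intended driving sequence is the usual prepare, calibration forward passes, convert flow through the public interface added above. The sketch below is hedged: the exact prepare/convert signatures are not part of this diff, so the call shape prepare(model, config) is an assumption, and calib_batches is a hypothetical iterable of calibration inputs.

import torch

from tico.quantization import convert, prepare
from tico.quantization.config.fpi_gptq import FPIGPTQConfig

def quantize_with_fpi_gptq(model, calib_batches):
    # model: a LLaMA-style causal LM exposing model.model.layers (see convert() above)
    q_model = prepare(model, FPIGPTQConfig())       # assumed signature: prepare(model, config)
    with torch.no_grad():
        for batch in calib_batches:                 # forwards populate the cached first-layer inputs
            q_model(batch)
    return convert(q_model)                         # runs the layer-by-layer FPI-GPTQ pass above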
tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py RENAMED
@@ -25,7 +25,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from tico.experimental.quantization.algorithm.gptq.quant import quantize, Quantizer
+from tico.quantization.algorithm.gptq.quant import quantize, Quantizer
 
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cudnn.allow_tf32 = False
@@ -36,6 +36,11 @@ class GPTQ:
         self.layer = layer
         self.dev = self.layer.weight.device
         W = layer.weight.data.clone()
+        if isinstance(self.layer, nn.Conv2d):
+            W = W.flatten(1)
+
+        if isinstance(self.layer, nn.Conv1d):
+            W = W.t()
         self.rows = W.shape[0]
         self.columns = W.shape[1]
         self.H: Optional[torch.Tensor] = torch.zeros(
@@ -48,10 +53,22 @@ class GPTQ:
         if len(inp.shape) == 2:
             inp = inp.unsqueeze(0)
         tmp = inp.shape[0]
-        if isinstance(self.layer, nn.Linear):
-            if len(inp.shape) == 3:
+        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, nn.Conv1d):
+            if len(inp.shape) > 2:
                 inp = inp.reshape((-1, inp.shape[-1]))
             inp = inp.t()
+        if isinstance(self.layer, nn.Conv2d):
+            unfold = nn.Unfold(
+                self.layer.kernel_size,
+                dilation=self.layer.dilation,
+                padding=self.layer.padding,
+                stride=self.layer.stride,
+            )
+
+            inp = unfold(inp)
+            inp = inp.permute([1, 0, 2])
+            inp = inp.flatten(1)
+
         self.H *= self.nsamples / (self.nsamples + tmp)
         self.nsamples += tmp
         inp = math.sqrt(2 / self.nsamples) * inp.float()
@@ -67,6 +84,10 @@ class GPTQ:
         verbose=False,
     ):
         W = self.layer.weight.data.clone()
+        if isinstance(self.layer, nn.Conv2d):
+            W = W.flatten(1)
+        if isinstance(self.layer, nn.Conv1d):
+            W = W.t()
         W = W.float()
         tick = time.time()
        if not self.quantizer.ready():
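The new Conv2d branch in add_batch above relies on the im2col identity: for a Conv2d with groups == 1, the convolution output equals the flattened (out_channels, in_channels*kh*kw) weight multiplied by the nn.Unfold patch matrix, so the unfolded activations can feed the same X·Xᵀ Hessian accumulation used for nn.Linear. A small self-contained check of that identity (layer sizes are arbitrary):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
x = torch.randn(2, 3, 16, 16)

unfold = nn.Unfold(
    conv.kernel_size, dilation=conv.dilation, padding=conv.padding, stride=conv.stride
)
cols = unfold(x)                      # (N, C*kh*kw, L) patch matrix, as in add_batch
W = conv.weight.flatten(1)            # (out_ch, C*kh*kw), as in __init__/fasterquant
out_ref = conv(x).flatten(2)          # (N, out_ch, L)
out_mm = torch.einsum("ok,nkl->nol", W, cols)
print(torch.allclose(out_ref, out_mm, atol=1e-5))  # True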
tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py RENAMED
@@ -19,15 +19,15 @@ from typing import Any, Callable, Dict, List, Optional
 import torch
 from tqdm.auto import tqdm
 
-from tico.experimental.quantization.algorithm.gptq.gptq import GPTQ
-from tico.experimental.quantization.algorithm.gptq.utils import (
+from tico.quantization.algorithm.gptq.gptq import GPTQ
+from tico.quantization.algorithm.gptq.utils import (
     find_layers,
     gather_single_batch_from_dict,
     gather_single_batch_from_list,
 )
-from tico.experimental.quantization.config.gptq import GPTQConfig
-from tico.experimental.quantization.quantizer import BaseQuantizer
-from tico.experimental.quantization.quantizer_registry import register_quantizer
+from tico.quantization.config.gptq import GPTQConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import register_quantizer
 
 
 class StopForward(Exception):
@@ -193,7 +193,13 @@ class GPTQQuantizer(BaseQuantizer):
             )
         ):
             # 1) Identify quantizable submodules within the layer
-            full = find_layers(layer)
+            full = find_layers(layer, layers=[torch.nn.Linear, torch.nn.Conv2d])
+            # filter out depthwise convolutions and alike
+            full = {
+                key: full[key]
+                for key in full.keys()
+                if not isinstance(full[key], torch.nn.Conv2d) or full[key].groups == 1
+            }
             sequential = [list(full.keys())]
 
             # 2) Set up GPTQ objects and gather stats
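The added filter keeps only dense convolutions: a depthwise or grouped Conv2d stores its weight as (out_channels, in_channels/groups, kh, kw), which does not flatten into the (out_channels, in_channels*kh*kw) matrix that the unfold-based Hessian accumulation expects. A quick illustration of which modules survive the filter; find_layers is approximated here with a plain named_modules scan:

import torch.nn as nn

layer = nn.Sequential(
    nn.Conv2d(16, 16, 3, padding=1, groups=16),  # depthwise -> filtered out
    nn.Conv2d(16, 32, 1),                        # dense 1x1 -> kept
    nn.Linear(32, 32),                           # kept
)
full = {
    name: m
    for name, m in layer.named_modules()
    if isinstance(m, (nn.Linear, nn.Conv2d))
}
full = {
    key: full[key]
    for key in full.keys()
    if not isinstance(full[key], nn.Conv2d) or full[key].groups == 1
}
print(sorted(full))  # ['1', '2']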
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py RENAMED
@@ -25,14 +25,12 @@ from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObser
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
 from torch.ao.quantization.quantizer.utils import _get_module_name_filter
 
-from tico.experimental.quantization.algorithm.pt2e.annotation.op import *
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
-from tico.experimental.quantization.algorithm.pt2e.transformation.convert_scalars_to_attrs import (
+from tico.quantization.algorithm.pt2e.annotation.op import *
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
+from tico.quantization.algorithm.pt2e.transformation.convert_scalars_to_attrs import (
     convert_scalars_to_attrs,
 )
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py RENAMED
@@ -19,12 +19,10 @@ if TYPE_CHECKING:
 import torch
 from torch.ao.quantization.quantizer import SharedQuantizationSpec
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import AdaptiveAvgPool2dArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py RENAMED
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
 import torch.fx
 import torch
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import AddTensorArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py RENAMED
@@ -19,12 +19,10 @@ if TYPE_CHECKING:
 import torch
 from torch.ao.quantization.quantizer import DerivedQuantizationSpec
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import Conv2DArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py RENAMED
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
 import torch.fx
 import torch
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import DivTensorArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py RENAMED
@@ -19,12 +19,10 @@ if TYPE_CHECKING:
 import torch
 from torch.ao.quantization.quantizer import DerivedQuantizationSpec
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import LinearArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py RENAMED
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
 import torch.fx
 import torch
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import MeanDimArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py RENAMED
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
 import torch.fx
 import torch
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import MulTensorArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py RENAMED
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
 import torch.fx
 import torch
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import Relu6Args
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py RENAMED
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
 import torch.fx
 import torch
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import RsqrtArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py RENAMED
@@ -18,12 +18,10 @@ if TYPE_CHECKING:
 import torch.fx
 import torch
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 from tico.utils.validate_args_kwargs import SubTensorArgs
 
 
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py RENAMED
@@ -18,9 +18,7 @@ if TYPE_CHECKING:
 import torch.fx
 import torch
 
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
 
 AnnotatorType = Callable[
     [
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py RENAMED
@@ -22,7 +22,7 @@ from torch.ao.quantization.quantizer import (
     SharedQuantizationSpec,
 )
 
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
 
 
 def annotate_input_qspec_map(node: torch.fx.Node, input_node: torch.fx.Node, qspec):