tico-0.1.0.dev250411-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/experimental/quantization/algorithm/gptq/quantizer.py
@@ -0,0 +1,225 @@
+ # Copyright (c) 2024 Intel Corporation
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import types
+ from typing import Any, Dict, List, Optional
+
+ import torch
+
+ from tico.experimental.quantization.algorithm.gptq.gptq import GPTQ
+ from tico.experimental.quantization.algorithm.gptq.utils import (
+     find_layers,
+     gather_single_batch_from_dict,
+     gather_single_batch_from_list,
+ )
+ from tico.experimental.quantization.config import BaseConfig, GPTQConfig
+ from tico.experimental.quantization.quantizer import BaseQuantizer
+
+
+ class GPTQQuantizer(BaseQuantizer):
+     """
+     Quantizer for applying the GPTQ algorithm (typically for weight quantization).
+     """
+
+     def __init__(self, config: BaseConfig):
+         super().__init__(config)
+
+         self.cache_args: List[Any] = []
+         self.cache_kwargs: Dict[str, Any] = {"batch_num": 0}
+
+     @torch.no_grad()
+     def prepare(
+         self,
+         model: torch.nn.Module,
+         args: Optional[Any] = None,
+         kwargs: Optional[Dict[str, Any]] = None,
+     ):
+         """
+         Overrides the forward method of the first LLaMA layer (layer 0) to capture the
+         input required for calibration.
+
+         This method modifies the original forward pass of LLaMA layer 0 so that the
+         inputs used during inference are intercepted and recorded. These captured inputs
+         are then used to calibrate the quantization parameters for GPTQ.
+
+         Parameters:
+             model: The target PyTorch model.
+             args: Positional example inputs required for capturing the graph.
+             kwargs: Keyword example inputs required for capturing the graph.
+
+         Returns:
+             The model prepared for GPTQ quantization.
+         """
+         if args is None and kwargs is None:
+             raise RuntimeError(
+                 "Either args or kwargs must be provided for capturing the graph."
+             )
+         # Define a function to capture input activations and associated parameters.
+         def forward(layer, *args, **kwargs):
+             self.cache_kwargs["batch_num"] += 1
+             for idx, item in enumerate(args):
+                 if (idx + 1) > len(self.cache_args):
+                     self.cache_args.append([])
+                 self.cache_args[idx].append(item)
+             for arg in kwargs:
+                 if self.cache_kwargs.get(arg, None) is None:
+                     self.cache_kwargs[arg] = []
+                 self.cache_kwargs[arg].append(kwargs[arg])
+             # Raise an error to interrupt the forward pass after capturing data.
+             raise ValueError
+
+         # Replace the first layer's forward with the capture function defined above.
+         if hasattr(model, "model"):
+             assert hasattr(model.model, "layers")
+             assert isinstance(model.model.layers, torch.nn.ModuleList)
+             layer_forward_cache = model.model.layers[0].forward
+             model.model.layers[0].forward = types.MethodType(
+                 forward, model.model.layers[0]
+             )
+         else:
+             assert hasattr(model, "forward")
+             layer_forward_cache = model.forward
+             model.forward = types.MethodType(forward, model.forward)
+
+         model_forward_cache = model.forward
+         # Replace the model's forward to suppress the ValueError raised during capture.
+         def model_forward(model, *args, **kwargs):
+             nonlocal model_forward_cache
+             try:
+                 model_forward_cache(*args, **kwargs)
+             except ValueError:
+                 pass
+
+         model.forward = types.MethodType(model_forward, model)
+         kwargs = kwargs or {}
+         model(*args, **kwargs)  # type: ignore[misc]
+
+         # Recover the original forward.
+         model.forward = model_forward_cache
+         if hasattr(model, "model"):
+             assert hasattr(model.model, "layers")
+             assert isinstance(model.model.layers, torch.nn.ModuleList)
+             model.model.layers[0].forward = layer_forward_cache
+         else:
+             model.forward = layer_forward_cache
+
+         return model
+
+     @torch.no_grad()
+     def convert(self, model):
+         """
+         Convert the prepared model to its GPTQ quantized version.
+
+         Applies GPTQ quantization to the weights based on the collected statistics.
+
+         Parameters:
+             model: The prepared PyTorch model.
+
+         Returns:
+             The quantized model.
+         """
+         gptq_conf = self.config
+         assert isinstance(gptq_conf, GPTQConfig)
+
+         # Save the original cache setting and disable caching during calibration/inference.
+         if hasattr(model, "config"):
+             use_cache = model.config.use_cache
+             model.config.use_cache = False
+
+         quantizers = {}
+         if hasattr(model, "model"):
+             target_layers = model.model.layers
+         else:
+             target_layers = [model]
+         for l_idx, layer in enumerate(target_layers):
+             # Identify quantizable submodules within the layer.
+             full = find_layers(layer)
+
+             sequential = [list(full.keys())]
+             for names in sequential:
+                 subset = {n: full[n] for n in names}
+
+                 gptq: Dict[str, GPTQ] = {}
+                 for name in subset:
+                     gptq[name] = GPTQ(subset[name])
+                     gptq[name].quantizer.configure(
+                         8, perchannel=True, sym=False, mse=False
+                     )
+                 # Define a hook to collect input/output batches for quantizer calibration.
+                 def add_batch(name):
+                     def tmp(_, inp, out):
+                         gptq[name].add_batch(inp[0].data, out.data)
+
+                     return tmp
+
+                 handles = []
+                 for name in subset:
+                     handles.append(subset[name].register_forward_hook(add_batch(name)))
+                 # Run the current layer on the stored calibration inputs to capture activation stats.
+                 batch_num = self.cache_kwargs.pop("batch_num")
+                 for batch_idx in range(batch_num):
+                     cache_args_batch = gather_single_batch_from_list(
+                         self.cache_args, batch_idx
+                     )
+                     cache_kwargs_batch = gather_single_batch_from_dict(
+                         self.cache_kwargs, batch_idx
+                     )
+                     layer(*cache_args_batch, **cache_kwargs_batch)[0]
+                 self.cache_kwargs["batch_num"] = batch_num
+                 for h in handles:
+                     h.remove()
+                 # Quantize each submodule using the collected calibration data.
+                 for name in subset:
+                     if gptq_conf.verbose:
+                         print(l_idx, name)
+                         print("Quantizing ...")
+                     gptq[name].fasterquant(
+                         percdamp=0.01,
+                         groupsize=-1,
+                         actorder=True,
+                         static_groups=False,
+                         verbose=gptq_conf.verbose,
+                     )
+                     quantizers["model.layers.%d.%s" % (l_idx, name)] = gptq[
+                         name
+                     ].quantizer
+                     gptq[name].free()
+             """
+             Execute the quantized layer with the calibration inputs to obtain outputs
+             that will serve as inputs for the next layer.
+
+             This ensures that the quantization effects are correctly propagated to subsequent
+             layers.
+             """
+             batch_num = self.cache_kwargs.pop("batch_num")
+             for batch_idx in range(batch_num):
+                 cache_args_batch = gather_single_batch_from_list(
+                     self.cache_args, batch_idx
+                 )
+                 cache_kwargs_batch = gather_single_batch_from_dict(
+                     self.cache_kwargs, batch_idx
+                 )
+                 outs = layer(*cache_args_batch, **cache_kwargs_batch)[0]
+                 # Update inputs for the next iteration.
+                 self.cache_args[0][batch_idx] = outs
+             self.cache_kwargs["batch_num"] = batch_num
+
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         # Restore the original cache configuration.
+         if hasattr(model, "config"):
+             model.config.use_cache = use_cache
+
+         return model
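
For orientation, here is a minimal usage sketch of GPTQQuantizer as defined above; the HuggingFace model name, the tokenizer call, and the default GPTQConfig() arguments are illustrative assumptions rather than anything taken from the package.

# Hedged usage sketch; model/tokenizer choice and GPTQConfig defaults are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from tico.experimental.quantization.algorithm.gptq.quantizer import GPTQQuantizer
from tico.experimental.quantization.config import GPTQConfig

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
sample = tokenizer("A short calibration sentence.", return_tensors="pt")

quantizer = GPTQQuantizer(GPTQConfig())
# prepare() hooks layer 0, runs the model once, and records the layer-0 inputs.
model = quantizer.prepare(model, args=(sample["input_ids"],))
# convert() replays the recorded inputs layer by layer and quantizes the weights.
model = quantizer.convert(model)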
tico/experimental/quantization/algorithm/gptq/utils.py
@@ -0,0 +1,65 @@
+ # Copyright (c) 2024 Intel Corporation
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+
+ def find_layers(module, layers=[torch.nn.Linear], name=""):
+     if type(module) in layers:
+         return {name: module}
+     res = {}
+     for name1, child in module.named_children():
+         res.update(
+             find_layers(
+                 child, layers=layers, name=name + "." + name1 if name != "" else name1
+             )
+         )
+     return res
+
+
+ def gather_single_batch_from_dict(data_dict, idx):
+     """
+     Gather a single batch from a dict.
+
+     Args:
+         data_dict (dict): data dict.
+         idx (int): index
+
+     Returns:
+         dict: single batch.
+     """
+     # obtain one set of keyword inputs from the cache
+     single_batch = {}
+     for k, v in data_dict.items():
+         single_batch[k] = data_dict[k][idx]
+     return single_batch
+
+
+ def gather_single_batch_from_list(data_list, idx):
+     """
+     Gather a single batch from a list.
+
+     Args:
+         data_list (list): data list.
+         idx (int): index
+
+     Returns:
+         list: single batch.
+     """
+     # obtain one set of positional inputs from the cache
+     single_batch = []
+     for data_item in data_list:
+         single_batch.append(data_item[idx])
+     return single_batch
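
A small illustration of how these helpers behave; the toy module and cached values below are made up for the example and are not part of the package.

# Illustrative only: toy inputs showing what the helpers return.
import torch

from tico.experimental.quantization.algorithm.gptq.utils import (
    find_layers,
    gather_single_batch_from_dict,
    gather_single_batch_from_list,
)

block = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 2))
# Returns Linear submodules keyed by their dotted path: {'0': Linear(...), '2': Linear(...)}
print(find_layers(block))

cache_args = [[torch.zeros(1, 4), torch.ones(1, 4)]]  # one positional slot, two cached batches
cache_kwargs = {"attention_mask": [None, None]}
print(gather_single_batch_from_list(cache_args, 1))   # -> [tensor of ones] (batch index 1)
print(gather_single_batch_from_dict(cache_kwargs, 1)) # -> {'attention_mask': None}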
tico/experimental/quantization/algorithm/pt2e/__init__.py
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py
@@ -0,0 +1,215 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import functools
+ from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+     from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+ import torch
+ from torch.ao.quantization.observer import (
+     MinMaxObserver,
+     MovingAverageMinMaxObserver,
+     MovingAveragePerChannelMinMaxObserver,
+     PerChannelMinMaxObserver,
+ )
+ from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
+ from torch.ao.quantization.quantizer.utils import _get_module_name_filter
+
+ from tico.experimental.quantization.algorithm.pt2e.annotation.op import *
+ import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
+ import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
+ import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
+ from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
+     QuantizationConfig,
+ )
+ from tico.experimental.quantization.algorithm.pt2e.transformation.convert_scalars_to_attrs import (
+     convert_scalars_to_attrs,
+ )
+
+
+ class PT2EAnnotator(Quantizer):
+     """
+     This class annotates quantization configurations on each node.
+
+     Observers are then attached according to those configurations by
+     'torch.ao.quantization.quantize_pt2e.prepare_pt2e'.
+     """
+
+     def __init__(self):
+         super().__init__()
+         self.global_config: Optional[QuantizationConfig] = None
+         self.operator_type_config: Dict[
+             torch._ops.OpOverloadPacket, QuantizationConfig
+         ] = {}
+         self.module_type_config: Dict[Callable, QuantizationConfig] = {}
+         self.module_name_config: Dict[str, QuantizationConfig] = {}
+
+     def set_global(self, quantization_config: QuantizationConfig) -> PT2EAnnotator:
+         """
+         Set the quantization config globally.
+         """
+         assert quantization_config is not None
+         self.global_config = quantization_config
+         return self
+
+     def set_operator_type(
+         self,
+         operator_type: torch._ops.OpOverloadPacket,
+         quantization_config: QuantizationConfig,
+     ) -> PT2EAnnotator:
+         """
+         Set the quantization config for the given operator type.
+         """
+         assert quantization_config is not None
+         self.operator_type_config[operator_type] = quantization_config
+         return self
+
+     def set_module_type(
+         self, module_type: Callable, quantization_config: QuantizationConfig
+     ):
+         """
+         Set the quantization config for the given module type.
+
+         For example, let's say quantizer.set_module_type(nn.Linear).
+         It will quantize all 'nn.Linear' modules with the `quantization_config`.
+         """
+         assert quantization_config is not None
+         self.module_type_config[module_type] = quantization_config
+         return self
+
+     def set_module_name(
+         self, module_name: str, quantization_config: QuantizationConfig
+     ):
+         """
+         Set the quantization config for the given module name.
+
+         For example, let's say quantizer.set_module_name("blocks.sub").
+         It will quantize all nodes that come from a module whose name is "blocks.sub"
+         with the `quantization_config`.
+         """
+         assert quantization_config is not None
+         self.module_name_config[module_name] = quantization_config
+         return self
+
+     def transform_for_annotation(
+         self, model: torch.fx.GraphModule
+     ) -> torch.fx.GraphModule:
+         """Allows user-defined transforms to run before annotating the graph."""
+         model = convert_scalars_to_attrs(model)
+         return model
+
+     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+         model = self._annotate_for_quantization(model)
+         annot_utils.propagate_annotation_forward(model)
+         return model
+
+     def validate(self, model: torch.fx.GraphModule) -> None:
+         # TODO Consider this method.
+         pass
+
+     def _annotate_by_config_and_filter(
+         self,
+         model: torch.fx.GraphModule,
+         quantization_config: Optional[QuantizationConfig],
+         filter_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
+     ) -> torch.fx.GraphModule:
+         assert quantization_config is not None
+
+         for node in model.graph.nodes:
+             if node.target not in annot_spec.OP_TO_ANNOTATOR:
+                 continue
+             annot_spec.OP_TO_ANNOTATOR[node.target](
+                 model, node, quantization_config, filter_fn
+             )
+         return model
+
+     def _annotate_for_quantization(
+         self, model: torch.fx.GraphModule
+     ) -> torch.fx.GraphModule:
+         # Annotate according to the given module names.
+         module_name_list = list(self.module_name_config.keys())
+         for module_name, config in self.module_name_config.items():
+             self._annotate_by_config_and_filter(
+                 model, config, _get_module_name_filter(module_name)
+             )
+
+         # Annotate according to the given module types.
+         tp_list = list(self.module_type_config.keys())
+         for module_type, config in self.module_type_config.items():
+             self._annotate_by_config_and_filter(
+                 model, config, quant_utils.get_module_type_filter(module_type)
+             )
+
+         # TODO Annotate according to the given operator types.
+
+         self._annotate_by_config_and_filter(
+             model,
+             self.global_config,
+             quant_utils.get_not_module_type_or_name_filter(tp_list, module_name_list),
+         )
+         return model
+
+
+ @functools.lru_cache
+ def get_asymmetric_quantization_config(
+     weight_is_per_channel: bool = True,
+     act_qmin: int = 0,
+     act_qmax: int = 255,
+     weight_qmin: int = 0,
+     weight_qmax: int = 255,
+ ) -> QuantizationConfig:
+     # activation
+     act_extra_args: Dict[str, Any] = {"eps": 2**-12}
+     act_observer = MinMaxObserver
+     act_qspec = QuantizationSpec(
+         dtype=torch.uint8,
+         quant_min=act_qmin,
+         quant_max=act_qmax,
+         qscheme=torch.per_tensor_affine,
+         is_dynamic=False,
+         observer_or_fake_quant_ctr=act_observer.with_args(
+             **act_extra_args,
+         ),
+     )
+     # weight
+     weight_extra_args: Dict[str, Any] = {"eps": 2**-12}
+     weight_qscheme = (
+         torch.per_channel_affine if weight_is_per_channel else torch.per_tensor_affine
+     )
+     weight_observer: _ObserverOrFakeQuantizeConstructor = (
+         PerChannelMinMaxObserver if weight_is_per_channel else MinMaxObserver
+     )
+     weight_qspec = QuantizationSpec(
+         dtype=torch.uint8,
+         quant_min=weight_qmin,
+         quant_max=weight_qmax,
+         qscheme=weight_qscheme,
+         ch_axis=0,
+         is_dynamic=False,
+         observer_or_fake_quant_ctr=weight_observer.with_args(**weight_extra_args),
+     )
+
+     # Set the bias qspec in each annotation function.
+     bias_qspec = None
+     quantization_config = QuantizationConfig(
+         act_qspec,
+         act_qspec,
+         weight_qspec,
+         bias_qspec,
+     )
+     return quantization_config
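
A hedged sketch of how this annotator might plug into the standard PT2E flow; the toy model, the torch.export step, and the calibration call are assumptions and may differ across PyTorch versions.

# Hedged sketch; the export step and toy model are assumptions.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from tico.experimental.quantization.algorithm.pt2e.annotation.annotator import (
    PT2EAnnotator,
    get_asymmetric_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU6()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)
exported = torch.export.export(model, example_inputs).module()

annotator = PT2EAnnotator().set_global(get_asymmetric_quantization_config())
prepared = prepare_pt2e(exported, annotator)  # attaches observers per the annotations
prepared(*example_inputs)                     # calibration pass
quantized = convert_pt2e(prepared)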
tico/experimental/quantization/algorithm/pt2e/annotation/config.py
@@ -0,0 +1,26 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass
+ from typing import Optional
+
+ from torch.ao.quantization.quantizer import QuantizationSpec
+
+
+ @dataclass(eq=True, frozen=True)
+ class QuantizationConfig:
+     input_activation: Optional[QuantizationSpec]
+     output_activation: Optional[QuantizationSpec]
+     weight: Optional[QuantizationSpec]
+     bias: Optional[QuantizationSpec]
tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py
@@ -0,0 +1,21 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import glob
+ from os.path import basename, dirname, isfile, join
+
+ modules = glob.glob(join(dirname(__file__), "*.py"))
+ __all__ = [
+     basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py")
+ ]
tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py
@@ -0,0 +1,65 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Callable, Optional, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.ao.quantization.quantizer import SharedQuantizationSpec
+
+ import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
+ import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
+ import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
+ from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
+     QuantizationConfig,
+ )
+ from tico.utils.validate_args_kwargs import AdaptiveAvgPool2dArgs
+
+
+ @annot_spec.register_annotator([torch.ops.aten.adaptive_avg_pool2d.default])
+ def _annotate_adaptive_avg_pool2d(
+     gm: torch.fx.GraphModule,
+     node: torch.fx.Node,
+     quantization_config: Optional[QuantizationConfig],
+     filter_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
+ ):
+     if (
+         node.op != "call_function"
+         or node.target != torch.ops.aten.adaptive_avg_pool2d.default
+     ):
+         return
+     if filter_fn and not filter_fn(node):
+         return
+     if quant_utils.is_annotated(node):
+         return
+
+     args = AdaptiveAvgPool2dArgs(*node.args)  # type: ignore[arg-type]
+     input = args.input
+
+     assert isinstance(input, torch.fx.Node)
+     if (
+         "quantization_annotation" not in input.meta
+         or not input.meta["quantization_annotation"]._annotated
+         or input.meta["quantization_annotation"].output_qspec is None
+     ):
+         input_act_qspec = quant_utils.get_input_act_qspec(quantization_config)
+     else:
+         input_act_qspec = SharedQuantizationSpec(input)
+     annot_utils.annotate_input_qspec_map(node, input, input_act_qspec)
+
+     output_act_qspec = SharedQuantizationSpec((input, node))
+     annot_utils.annotate_output_qspec(node, output_act_qspec)
+
+     annot_utils.mark_nodes_as_annotated(node)
tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py
@@ -0,0 +1,57 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Callable, Optional, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+
+ import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
+ import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
+ import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
+ from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
+     QuantizationConfig,
+ )
+ from tico.utils.validate_args_kwargs import AddTensorArgs
+
+
+ @annot_spec.register_annotator([torch.ops.aten.add.Tensor])
+ def _annotate_add(
+     gm: torch.fx.GraphModule,
+     node: torch.fx.Node,
+     quantization_config: Optional[QuantizationConfig],
+     filter_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
+ ):
+     if node.op != "call_function" or node.target != torch.ops.aten.add.Tensor:
+         return
+     if filter_fn and not filter_fn(node):
+         return
+     if quant_utils.is_annotated(node):
+         return
+
+     args = AddTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+     input_ = args.input
+     other = args.other
+
+     input_act_qspec = quant_utils.get_input_act_qspec(quantization_config)
+     if isinstance(input_, torch.fx.Node):
+         annot_utils.annotate_input_qspec_map(node, input_, input_act_qspec)
+     if isinstance(other, torch.fx.Node):
+         annot_utils.annotate_input_qspec_map(node, other, input_act_qspec)
+
+     output_act_qspec = quant_utils.get_output_act_qspec(quantization_config)
+     annot_utils.annotate_output_qspec(node, output_act_qspec)
+
+     annot_utils.mark_nodes_as_annotated(node)
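
The per-op annotators above share a common shape: guard on the target op, apply any user filter, skip already-annotated nodes, then attach input/output qspecs and mark the node. Below is a hypothetical annotator for another elementwise op written in that same style; the choice of aten.relu is illustrative only and says nothing about what the package actually ships.

# Hypothetical sketch mirroring the add annotator; aten.relu is only an example target.
from typing import Callable, Optional

import torch

import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
from tico.experimental.quantization.algorithm.pt2e.annotation.config import QuantizationConfig


@annot_spec.register_annotator([torch.ops.aten.relu.default])
def _annotate_relu(
    gm: "torch.fx.GraphModule",
    node: "torch.fx.Node",
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[["torch.fx.Node"], bool]] = None,
):
    # Guard: only annotate aten.relu call_function nodes that pass the filter
    # and have not been annotated yet.
    if node.op != "call_function" or node.target != torch.ops.aten.relu.default:
        return
    if filter_fn and not filter_fn(node):
        return
    if quant_utils.is_annotated(node):
        return

    input_ = node.args[0]
    if isinstance(input_, torch.fx.Node):
        annot_utils.annotate_input_qspec_map(
            node, input_, quant_utils.get_input_act_qspec(quantization_config)
        )
    annot_utils.annotate_output_qspec(
        node, quant_utils.get_output_act_qspec(quantization_config)
    )
    annot_utils.mark_nodes_as_annotated(node)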