tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,459 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import copy
+
+ from collections import defaultdict
+ from typing import Any
+
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.serialize.quant_param import QPARAM_KEY, QuantParam
+ from tico.utils import logging
+ from tico.utils.errors import NotYetSupportedError
+ from tico.utils.graph import create_node
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.utils import quant_min_max, set_new_meta_val
+ from tico.utils.validate_args_kwargs import (
+     AddTensorArgs,
+     BmmArgs,
+     CatArgs,
+     LinearArgs,
+     MulTensorArgs,
+     PermuteArgs,
+     ReluArgs,
+     ReshapeArgs,
+ )
+
+
+ def qparam_dtype(node: torch.fx.Node) -> str:
+     assert QPARAM_KEY in node.meta
+     return node.meta[QPARAM_KEY].dtype
+
+
+ # Convert i16 qparam to u8 qparam
+ # scale and zero_point are inferred from i16 qparam
+ def _i16_to_u8(qparam: QuantParam) -> QuantParam:
+     # Assume per-tensor quantization
+     assert qparam.scale is not None and len(qparam.scale) == 1
+     assert qparam.dtype == "int16"
+
+     s16_scale = qparam.scale[0]
+     max_ = s16_scale * 32767  # numeric_limits<int16>
+     min_ = -max_
+
+     u8_scale = (max_ - min_) / 255
+     u8_zerop = round(-min_ / u8_scale)
+
+     new_qparam = QuantParam()
+     new_qparam.scale = [u8_scale]
+     new_qparam.zero_point = [u8_zerop]
+     new_qparam.dtype = "uint8"
+
+     return new_qparam
+
+
+ # Convert u8 qparam to i16 qparam
+ # scale is inferred from u8 qparam
+ def _u8_to_i16(qparam: QuantParam) -> QuantParam:
+     # Assume per-tensor quantization
+     assert qparam.scale is not None and len(qparam.scale) == 1
+     assert qparam.zero_point is not None and len(qparam.zero_point) == 1
+     assert qparam.dtype == "uint8"
+
+     u8_scale = qparam.scale[0]
+     u8_zerop = qparam.zero_point[0]
+     max_ = u8_scale * (255 - u8_zerop)
+     min_ = u8_scale * (-u8_zerop)
+
+     abs_max = max(abs(max_), abs(min_))
+     s16_scale = abs_max / 32767
+     s16_zerop = 0
+
+     new_qparam = QuantParam()
+     new_qparam.scale = [s16_scale]
+     new_qparam.zero_point = [s16_zerop]
+     new_qparam.dtype = "int16"
+
+     return new_qparam
+
+
+ def _insert_quantize_op_before(node, inp):
+     graph = node.graph
+     qparam: QuantParam = node.meta[QPARAM_KEY]
+     assert qparam.scale is not None
+     assert qparam.zero_point is not None
+     scale = qparam.scale[0]
+     zerop = qparam.zero_point[0]
+     min_, max_ = quant_min_max(qparam.dtype)
+     dtype = getattr(torch, qparam.dtype)
+
+     with graph.inserting_before(node):
+         q_args = (inp, scale, zerop, min_, max_, dtype)
+         quantize = create_node(
+             graph,
+             torch.ops.quantized_decomposed.quantize_per_tensor.default,
+             args=q_args,
+             origin=node,
+         )
+         quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
+         set_new_meta_val(quantize)
+
+     node.replace_input_with(inp, quantize)
+
+     return quantize
+
+
+ def _insert_quantize_op_after(node):
+     graph = node.graph
+     qparam: QuantParam = node.meta[QPARAM_KEY]
+     assert qparam.scale is not None
+     assert qparam.zero_point is not None
+     scale = qparam.scale[0]
+     zerop = qparam.zero_point[0]
+     min_, max_ = quant_min_max(qparam.dtype)
+     dtype = getattr(torch, qparam.dtype)
+     with graph.inserting_after(node):
+         q_args = (node, scale, zerop, min_, max_, dtype)
+         quantize = create_node(
+             graph,
+             torch.ops.quantized_decomposed.quantize_per_tensor.default,
+             args=q_args,
+         )
+
+     node.replace_all_uses_with(quantize, propagate_meta=True)
+     quantize.replace_input_with(quantize, node)
+
+     quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
+
+     return quantize
+
+
+ def _linear_handler(node, logger):
+     lin_args = LinearArgs(*node.args, **node.kwargs)
+     inp = lin_args.input
+
+     if QPARAM_KEY not in inp.meta:
+         return
+
+     if QPARAM_KEY not in node.meta:
+         return
+
+     if qparam_dtype(inp) == qparam_dtype(node):
+         return
+
+     if qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+         quantize = _insert_quantize_op_after(node)
+
+         quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+
+         # Update node's qparam from i16 to u8
+         # NOTE This would severely degrade accuracy. It is
+         # important to mitigate this accuracy drop in backend.
+         node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+         logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+     else:
+         raise NotYetSupportedError(
+             f"Unsupported dtype: From {qparam_dtype(inp)} to {qparam_dtype(node)}"
+         )
+
+
+ def _add_handler(node, logger):
+     add_args = AddTensorArgs(*node.args, **node.kwargs)
+     x = add_args.input
+     y = add_args.other
+
+     if not isinstance(x, torch.fx.Node):
+         return
+     if not isinstance(y, torch.fx.Node):
+         return
+
+     if QPARAM_KEY not in x.meta:
+         return
+     if QPARAM_KEY not in y.meta:
+         return
+     if QPARAM_KEY not in node.meta:
+         return
+
+     if qparam_dtype(x) == qparam_dtype(node):
+         return
+
+     if qparam_dtype(x) != qparam_dtype(y):
+         return
+
+     if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+         quantize = _insert_quantize_op_after(node)
+
+         quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+         node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+         logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+     else:
+         raise NotYetSupportedError("Unsupported dtype")
+
+
+ def _mul_handler(node, logger):
+     mul_args = MulTensorArgs(*node.args, **node.kwargs)
+     x = mul_args.input
+     y = mul_args.other
+
+     if not isinstance(x, torch.fx.Node):
+         return
+     if not isinstance(y, torch.fx.Node):
+         return
+
+     if QPARAM_KEY not in x.meta:
+         return
+     if QPARAM_KEY not in y.meta:
+         return
+     if QPARAM_KEY not in node.meta:
+         return
+
+     if qparam_dtype(x) == qparam_dtype(node):
+         return
+
+     if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+         quantize = _insert_quantize_op_after(node)
+
+         quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+         node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+         logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+     else:
+         raise NotYetSupportedError("Unsupported dtype")
+
+
+ def _cat_handler(node, logger):
+     cat_args = CatArgs(*node.args, **node.kwargs)
+     tensors = cat_args.tensors
+
+     if any(QPARAM_KEY not in x.meta for x in tensors):
+         return
+
+     if QPARAM_KEY not in node.meta:
+         return
+
+     assert len(tensors) > 0
+     in_dtype = qparam_dtype(tensors[0])
+     if in_dtype == qparam_dtype(node):
+         return
+
+     if any(qparam_dtype(x) != in_dtype for x in tensors):
+         return
+
+     if in_dtype == "int16" and qparam_dtype(node) == "uint8":
+         quantize = _insert_quantize_op_after(node)
+
+         quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+         node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+         logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+     else:
+         raise NotYetSupportedError("Unsupported dtype")
+
+
+ def _bmm_handler(node, logger):
+     bmm_args = BmmArgs(*node.args, **node.kwargs)
+     x = bmm_args.input
+     y = bmm_args.mat2
+
+     if QPARAM_KEY not in x.meta:
+         return
+     if QPARAM_KEY not in y.meta:
+         return
+     if QPARAM_KEY not in node.meta:
+         return
+
+     if qparam_dtype(x) == qparam_dtype(node):
+         return
+
+     if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+         quantize = _insert_quantize_op_after(node)
+
+         quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+         node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+         logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+     elif qparam_dtype(x) == "uint8" and qparam_dtype(node) == "int16":
+         quantize = _insert_quantize_op_after(node)
+
+         quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+         node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+         logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+     else:
+         raise NotYetSupportedError("Unsupported dtype")
+
+
+ def _permute_handler(node, logger):
+     per_args = PermuteArgs(*node.args, **node.kwargs)
+     inp = per_args.input
+
+     if QPARAM_KEY not in inp.meta:
+         return
+
+     if QPARAM_KEY not in node.meta:
+         return
+
+     if qparam_dtype(inp) == qparam_dtype(node):
+         return
+
+     if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
+         # A new Quantize Op (s16 to u8) is inserted before (not after)
+         # permute Op to reduce tensor size earlier
+         quantize = _insert_quantize_op_before(node, inp)
+
+         quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+         logger.debug(f"quantize_per_tensor.default is inserted before {node.name}.")
+     elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+         quantize = _insert_quantize_op_after(node)
+
+         quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+         node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+         logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+     else:
+         raise NotYetSupportedError("Unsupported dtype")
+
+
+ def _reshape_handler(node, logger):
+     reshape_args = ReshapeArgs(*node.args, **node.kwargs)
+     inp = reshape_args.input
+
+     if QPARAM_KEY not in inp.meta:
+         return
+
+     if QPARAM_KEY not in node.meta:
+         return
+
+     if qparam_dtype(inp) == qparam_dtype(node):
+         return
+
+     if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
+         # A new Quantize Op (s16 to u8) is inserted before (not after)
+         # reshape Op to reduce tensor size earlier
+         quantize = _insert_quantize_op_before(node, inp)
+
+         quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+         logger.debug(f"quantize_per_tensor.default is inserted before {node.name}.")
+     elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+         quantize = _insert_quantize_op_after(node)
+
+         quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+         node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+         logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+     else:
+         raise NotYetSupportedError("Unsupported dtype")
+
+
+ def _relu_handler(node, logger):
+     relu_args = ReluArgs(*node.args, **node.kwargs)
+     inp = relu_args.input
+
+     if QPARAM_KEY not in inp.meta:
+         return
+
+     if QPARAM_KEY not in node.meta:
+         return
+
+     if qparam_dtype(inp) == qparam_dtype(node):
+         return
+
+     if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
+         quantize = _insert_quantize_op_after(node)
+
+         quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+         node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+         logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+     elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+         quantize = _insert_quantize_op_after(node)
+
+         quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+         node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+         logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+     else:
+         raise NotYetSupportedError("Unsupported dtype")
+
+
+ _op_handler: defaultdict[Any, Any | None] = defaultdict(lambda: None)
+ _op_handler[torch.ops.aten.linear.default] = _linear_handler
+ _op_handler[torch.ops.aten.add.Tensor] = _add_handler
+ _op_handler[torch.ops.aten.mul.Tensor] = _mul_handler
+ _op_handler[torch.ops.aten.cat.default] = _cat_handler
+ _op_handler[torch.ops.aten.bmm.default] = _bmm_handler
+ _op_handler[torch.ops.aten.permute.default] = _permute_handler
+ _op_handler[torch.ops.aten.reshape.default] = _reshape_handler
+ _op_handler[torch.ops.aten.relu.default] = _relu_handler
+
+
+ @trace_graph_diff_on_pass
+ class InsertQuantizeOnDtypeMismatch(PassBase):
+     """
+     Insert quantize Op in the operators where circle's type inference is violated.
+     Example. FullyConnected
+     [BEFORE]
+       Op (uint8) - aten.linear.default (int16)
+     [AFTER]
+       Op (uint8) - aten.linear.default (uint8) - quantized_decomposed.quantize_per_tensor.default (int16)
+     Why is this pass necessary?
+     - For some operators, circle's type inference pass overwrites the input's dtype to
+       the output's dtype. For the above example, the fully-connected layer (aten.linear.default)'s
+       output dtype (int16) is updated to the input dtype (uint8), which breaks the semantics.
+       This problem can occur in tools (ex: circle2circle) that automatically apply type inference.
+     - To resolve the issue, we insert quantize operators so as not to violate circle's type inference logic.
+     - NOTE For some cases, Quantize Op is inserted before the operators.
+
+       Let's assume Reshape Op's input is int16 and output is uint8. There are two possible places to insert
+       Quantize Op.
+
+       1. Insert Quantize before Reshape.
+
+       ```
+       Predecessor (int16) -> Quantize (uint8) -> Reshape (uint8) -> ...
+       ```
+
+       2. Insert Quantize after Reshape.
+
+       ```
+       Predecessor (int16) -> Reshape (int16) -> Quantize (uint8) -> ...
+       ```
+
+       Comparing 1) and 2), the difference is whether the Reshape operation is conducted in uint8 or int16.
+       We go with 1), which does Reshape in uint8, for faster execution. Note that Reshape Op does not
+       change the values, so its dtype does not affect accuracy.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph: torch.fx.Graph = graph_module.graph
+
+         for node in graph.nodes:
+             if node.op != "call_function":
+                 continue
+
+             handler = _op_handler[node.target]
+             if handler is not None:
+                 handler(node, logger)
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         # Run only once.
+         return PassResult(False)
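For readers skimming the new pass: the `_i16_to_u8` / `_u8_to_i16` rewrites above reduce to a few lines of scale/zero-point arithmetic. The standalone sketch below (plain Python, no tico imports; the sample scale value is made up) mirrors those formulas to show how a symmetric int16 range maps onto an asymmetric uint8 range and back.

```python
# Standalone sketch of the qparam conversions in the pass above; values are illustrative.
def i16_to_u8(s16_scale: float) -> tuple[float, int]:
    max_ = s16_scale * 32767            # int16 covers [-max_, max_], symmetric around 0
    min_ = -max_
    u8_scale = (max_ - min_) / 255      # spread the same float range over 256 levels
    u8_zerop = round(-min_ / u8_scale)  # zero-point lands near the middle (~128)
    return u8_scale, u8_zerop

def u8_to_i16(u8_scale: float, u8_zerop: int) -> float:
    max_ = u8_scale * (255 - u8_zerop)  # float range actually covered by uint8
    min_ = u8_scale * (-u8_zerop)
    return max(abs(max_), abs(min_)) / 32767  # int16 scale; int16 zero-point is always 0

u8_s, u8_z = i16_to_u8(0.001)           # hypothetical int16 scale
print(u8_s, u8_z)                       # ~0.257, 128
print(u8_to_i16(u8_s, u8_z))            # ~0.001004, close to the original scale
```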
@@ -16,7 +16,6 @@ from typing import TYPE_CHECKING
  
  if TYPE_CHECKING:
      import torch.fx
- import copy
  
  import torch
  from torch.export import ExportedProgram
@@ -53,7 +53,7 @@ class ValRange:
          if isinstance(val, torch.Tensor):
              self.max = torch.max(val).item()
              self.min = torch.min(val).item()
-         elif type(val) == list:
+         elif isinstance(val, list):
              self.max = max(val)
              self.min = min(val)
          else:
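The one-line change above swaps an exact type comparison for `isinstance`; the practical difference shows up with list subclasses, as this small standalone snippet illustrates.

```python
class MyList(list):  # any list subclass
    pass

vals = MyList([1, 2, 3])
print(type(vals) == list)      # False: exact type check rejects subclasses
print(isinstance(vals, list))  # True: the new check accepts them as well
```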
@@ -13,25 +13,17 @@
  # limitations under the License.
  
  import copy
- from typing import Any, Dict, Optional, Type
+ from typing import Any, Dict, Optional
  
  import torch
  
- from tico.experimental.quantization.algorithm.gptq.quantizer import GPTQQuantizer
- from tico.experimental.quantization.algorithm.pt2e.quantizer import PT2EQuantizer
- from tico.experimental.quantization.algorithm.smoothquant.quantizer import (
-     SmoothQuantQuantizer,
- )
- from tico.experimental.quantization.config import BaseConfig
- from tico.experimental.quantization.quantizer import BaseQuantizer
+ from tico.quantization.algorithm.gptq.quantizer import GPTQQuantizer
+ from tico.quantization.algorithm.pt2e.quantizer import PT2EQuantizer
+ from tico.quantization.config.base import BaseConfig
+ from tico.quantization.quantizer import BaseQuantizer
+ from tico.quantization.quantizer_registry import get_quantizer
  
  
- config_to_quantizer: Dict[str, Type[BaseQuantizer]] = {
-     "pt2e": PT2EQuantizer,
-     "gptq": GPTQQuantizer,
-     "smooth_quant": SmoothQuantQuantizer,
- }
-
  QUANTIZER_ATTRIBUTE_NAME = "tico_quantizer"
  
  
@@ -40,7 +32,7 @@ def prepare(
      quant_config: BaseConfig,
      args: Optional[Any] = None,
      kwargs: Optional[Dict[str, Any]] = None,
-     inplace: Optional[bool] = False,
+     inplace: Optional[bool] = True,
  ):
      """
      Prepare the model for quantization using the provided configuration.
@@ -59,21 +51,24 @@
      Returns:
          The model prepared for quantization.
      """
-     if quant_config.name == "pt2e" and inplace:
+     if hasattr(model, QUANTIZER_ATTRIBUTE_NAME):
+         raise RuntimeError("prepare() already has been called.")
+     quantizer = get_quantizer(quant_config)
+
+     if isinstance(quantizer, PT2EQuantizer) and inplace:
          raise RuntimeError(
              "In-place is not supported for PT2E quantization due to limitation in the underlying Torch APIs. Please set 'inplace=False' to proceed."
          )
  
      model = model if inplace else copy.deepcopy(model)
  
-     quantizer = config_to_quantizer[quant_config.name](quant_config)
      model = quantizer.prepare(model, args, kwargs)
      setattr(model, QUANTIZER_ATTRIBUTE_NAME, quantizer)
  
      return model
  
  
- def convert(model, inplace: Optional[bool] = False):
+ def convert(model, inplace: Optional[bool] = True):
      """
      Convert the prepared model to a quantized model using the provided configuration.
  
@@ -99,6 +94,12 @@ def convert(model, inplace: Optional[bool] = False):
          raise RuntimeError(
              "In-place is not supported for PT2E quantization due to limitation in the underlying Torch APIs. Please set 'inplace=False' to proceed."
          )
+     # deepcopy prevents the quantizer from restoring the catcher used for calibration.
+     # TODO Revisit `inplace` policy.
+     if isinstance(quantizer, GPTQQuantizer) and not inplace:
+         raise RuntimeError(
+             "GPTQ quantization only supports `in-place=True`. Please set 'inplace=True' to proceed."
+         )
  
      model = model if inplace else copy.deepcopy(model)
  
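Taken together, the public-interface changes above flip the default to `inplace=True`, resolve the quantizer through the new registry, guard against calling `prepare()` twice, and reject `inplace=False` for GPTQ in `convert()`. Below is a minimal sketch of the resulting calling pattern, assuming only what the hunks show; `cfg` stands for any registered config object, and constructing one is out of scope here.

```python
# Sketch only: module path and signatures are taken from the hunks above.
from tico.quantization.public_interface import convert, prepare

def quantize(model, cfg, calibration_batches):
    # prepare() now defaults to inplace=True and raises if the model already
    # carries the `tico_quantizer` attribute from a previous prepare() call.
    model = prepare(model, cfg)
    for batch in calibration_batches:
        model(*batch)  # calibration forward passes observed by the quantizer
    # convert() also defaults to inplace=True; GPTQ additionally rejects inplace=False.
    return convert(model)
```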
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional
  
  import torch
  
- from tico.experimental.quantization.config import BaseConfig
+ from tico.quantization.config.base import BaseConfig
  
  
  class BaseQuantizer(ABC):
@@ -0,0 +1,73 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import importlib
+ from typing import Dict, Optional, Type, TypeVar
+
+ from tico.quantization.config.base import BaseConfig
+ from tico.quantization.quantizer import BaseQuantizer
+
+ TQ = TypeVar("TQ", bound=BaseQuantizer)
+
+ # Mapping: Config type -> Quantizer type
+ _REGISTRY: Dict[Type[BaseConfig], Type[BaseQuantizer]] = {}
+
+
+ def register_quantizer(config_cls: Type[BaseConfig]):
+     """
+     Decorator to register a quantizer for a given config class.
+     Usage:
+         @register_quantizer(GPTQConfig)
+         class GPTQQuantizer(BaseQuantizer): ...
+     """
+
+     def wrapper(quantizer_cls: Type[TQ]) -> Type[TQ]:
+         _REGISTRY[config_cls] = quantizer_cls
+         return quantizer_cls
+
+     return wrapper
+
+
+ def _lookup(cfg: BaseConfig) -> Optional[Type[BaseQuantizer]]:
+     """Return a quantizer class only if the exact config type is registered."""
+     return _REGISTRY.get(type(cfg))
+
+
+ def get_quantizer(cfg: BaseConfig) -> BaseQuantizer:
+     """Factory to return a quantizer instance for the given config."""
+     qcls = _lookup(cfg)
+     if qcls is not None:
+         return qcls(cfg)
+
+     # Lazy import by naming convention
+     name = getattr(cfg, "name", None)
+     if name:
+         if name == "ptq":
+             importlib.import_module(f"tico.quantization.wrapq.quantizer")
+         else:
+             try:
+                 importlib.import_module(f"tico.quantization.algorithm.{name}.quantizer")
+             except Exception as e:
+                 raise RuntimeError(
+                     f"Failed to import quantizer module for config name='{name}': {e}"
+                 )
+
+         qcls = _lookup(cfg)
+         if qcls is not None:
+             return qcls(cfg)
+
+     raise RuntimeError(
+         f"No quantizer registered for config type {type(cfg).__name__} "
+         f"(name='{getattr(cfg,'name',None)}')."
+     )
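The registry above resolves quantizers by exact config type first and only falls back to a lazy import keyed on `cfg.name`. A toy registration is sketched below; `DemoConfig` and `DemoQuantizer` are made-up stand-ins that do not subclass tico's real `BaseConfig`/`BaseQuantizer` (whose full abstract interface is not shown in this diff), which the registry's runtime lookup does not require.

```python
# Toy illustration of the registry mechanics shown in the new file above.
from tico.quantization.quantizer_registry import get_quantizer, register_quantizer

class DemoConfig:
    name = "demo"  # get_quantizer() only consults `name` for its lazy-import fallback

@register_quantizer(DemoConfig)
class DemoQuantizer:
    def __init__(self, cfg):
        self.cfg = cfg  # get_quantizer() instantiates the registered class with the config

q = get_quantizer(DemoConfig())  # exact-type hit in _REGISTRY, no lazy import needed
assert isinstance(q, DemoQuantizer)
```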
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE