tico-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
tico/experimental/quantization/passes/propagate_qparam_backward.py
@@ -0,0 +1,91 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import copy
+
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.serialize.quant_param import QPARAM_KEY
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.validate_args_kwargs import CatArgs, PermuteArgs, ReshapeArgs
+
+
+ @trace_graph_diff_on_pass
+ class PropagateQParamBackward(PassBase):
+     """
+     This pass propagates quantization parameters backward.
+
+     BEFORE)
+
+     node -> reshape (with meta[QPARAM_KEY])
+
+     AFTER)
+
+     node (with meta[QPARAM_KEY]) -> reshape (with meta[QPARAM_KEY])
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph: torch.fx.Graph = graph_module.graph
+
+         def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node):
+             if QPARAM_KEY not in src.meta:
+                 return
+
+             if (
+                 QPARAM_KEY in dst.meta
+                 and src.meta[QPARAM_KEY].dtype != dst.meta[QPARAM_KEY].dtype
+             ):
+                 return
+
+             dst.meta[QPARAM_KEY] = copy.deepcopy(src.meta[QPARAM_KEY])
+
+             logger.debug(f"{src.name}'s quantparam is propagated to {dst.name}.")
+
+         # Do reverse-order traversal for backward propagation
+         for node in reversed(graph.nodes):
+             if node.op != "call_function":
+                 continue
+             if node.target == torch.ops.aten.cat.default:
+                 concat_args = CatArgs(*node.args, **node.kwargs)
+                 concat_inputs = concat_args.tensors
+
+                 for concat_input in concat_inputs:
+                     _propagate_qparam_if_possible(node, concat_input)
+             elif node.target == torch.ops.aten.reshape.default:
+                 args = ReshapeArgs(*node.args, **node.kwargs)
+                 _propagate_qparam_if_possible(node, args.input)
+             elif node.target == torch.ops.aten.permute.default:
+                 permute_args = PermuteArgs(*node.args, **node.kwargs)
+                 _propagate_qparam_if_possible(node, permute_args.input)
+             # TODO Support more ops.
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         # Run only once.
+         return PassResult(False)
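
The propagation rule above reduces to a small metadata copy. Below is a minimal sketch with plain dicts standing in for torch.fx node metadata; the key string and qparam contents are illustrative, not the package's actual values:

    import copy

    QPARAM_KEY = "qparam"  # stand-in for tico.serialize.quant_param.QPARAM_KEY

    reshape_meta = {QPARAM_KEY: {"dtype": "int16", "scale": [0.01]}}  # reshape output
    producer_meta: dict = {}                                          # its input

    # Backward step: copy the reshape's qparam onto its producer when the
    # producer has none (the real pass also allows overwriting a qparam of
    # the same dtype, but never one of a different dtype).
    if QPARAM_KEY in reshape_meta and QPARAM_KEY not in producer_meta:
        producer_meta[QPARAM_KEY] = copy.deepcopy(reshape_meta[QPARAM_KEY])

    assert producer_meta[QPARAM_KEY]["dtype"] == "int16"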
tico/experimental/quantization/passes/propagate_qparam_forward.py
@@ -0,0 +1,141 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import copy
+
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.serialize.quant_param import QPARAM_KEY
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.validate_args_kwargs import (
+     CatArgs,
+     NegArgs,
+     PermuteArgs,
+     ReshapeArgs,
+     SliceArgs,
+ )
+
+
+ @trace_graph_diff_on_pass
+ class PropagateQParamForward(PassBase):
+     """
+     A pass that propagates quantization parameters through operations that do not alter them.
+
+     This pass identifies and propagates quantization parameters through operations that
+     do not change their values, such as `permute`, `reshape`, `transpose`, `view`, and
+     similar tensor transformations.
+
+     By ensuring that quantization parameters remain consistent across such operations,
+     this pass helps maintain correctness in quantization-aware representations.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node):
+             if QPARAM_KEY not in src.meta:
+                 return
+
+             if (
+                 QPARAM_KEY in dst.meta
+                 and src.meta[QPARAM_KEY].dtype != dst.meta[QPARAM_KEY].dtype
+             ):
+                 return
+
+             dst.meta[QPARAM_KEY] = copy.deepcopy(src.meta[QPARAM_KEY])
+
+             logger.debug(f"{src.name}'s quantparam is propagated to {dst.name}.")
+
+         graph_module = exported_program.graph_module
+         graph: torch.fx.Graph = graph_module.graph
+         for node in graph.nodes:
+             if node.op != "call_function":
+                 continue
+             if node.target == torch.ops.aten.permute.default:
+                 permute_args = PermuteArgs(*node.args, **node.kwargs)
+                 _propagate_qparam_if_possible(permute_args.input, node)
+             elif node.target == torch.ops.aten.reshape.default:
+                 reshape_args = ReshapeArgs(*node.args, **node.kwargs)
+                 _propagate_qparam_if_possible(reshape_args.input, node)
+             elif node.target == torch.ops.aten.slice.Tensor:
+                 slice_args = SliceArgs(*node.args, **node.kwargs)
+                 _propagate_qparam_if_possible(slice_args.input, node)
+             elif node.target == torch.ops.aten.neg.default:
+                 neg_args = NegArgs(*node.args, **node.kwargs)
+
+                 if QPARAM_KEY not in neg_args.input.meta:
+                     continue
+                 # Only support int16 for now
+                 if neg_args.input.meta[QPARAM_KEY].dtype != "int16":
+                     continue
+
+                 _propagate_qparam_if_possible(neg_args.input, node)
+
+             elif node.target == torch.ops.aten.cat.default:
+                 concat_args = CatArgs(*node.args, **node.kwargs)
+                 concat_inputs = concat_args.tensors
+
+                 cond = True
+                 for concat_input in concat_inputs:
+                     # Check that all inputs have a qparam
+                     if QPARAM_KEY not in concat_input.meta:
+                         cond = False
+                         break
+
+                     # Only support int16 for now
+                     if concat_input.meta[QPARAM_KEY].dtype != "int16":
+                         cond = False
+                         break
+
+                     if concat_input.meta[QPARAM_KEY].scale is None:
+                         cond = False
+                         break
+
+                     if len(concat_input.meta[QPARAM_KEY].scale) != 1:
+                         cond = False
+                         break
+
+                 if not cond:
+                     continue
+
+                 # Find the node with the max scale
+                 max_scale = 0.0
+                 max_scale_node = None
+                 for concat_input in concat_inputs:
+                     scale = concat_input.meta[QPARAM_KEY].scale[0]
+                     if max_scale < scale:
+                         max_scale = scale
+                         max_scale_node = concat_input
+
+                 assert max_scale_node is not None
+                 _propagate_qparam_if_possible(max_scale_node, node)
+
+             # TODO Support more ops.
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         # Run only once.
+         return PassResult(False)
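
The `cat` handling above propagates the qparam of the input with the largest scale to the concat output; intuitively, the coarsest input grid must be able to represent the whole concatenated range. A toy computation with made-up per-tensor scales:

    # Illustrative int16 per-tensor scales for three cat inputs.
    scales = {"x0": 0.002, "x1": 0.010, "x2": 0.005}

    max_scale_node = max(scales, key=lambda n: scales[n])
    assert max_scale_node == "x1"  # the cat output inherits x1's qparam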
tico/experimental/quantization/passes/quantize_bias.py
@@ -0,0 +1,123 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import copy
+
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.serialize.quant_param import QPARAM_KEY, QuantParam, to_qparam_dtype
+ from tico.utils import logging
+ from tico.utils.graph import add_placeholder, get_torch_param_value, is_torch_param
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.validate_args_kwargs import LinearArgs
+
+
+ @trace_graph_diff_on_pass
+ class QuantizeBias(PassBase):
+     """
+     Quantize bias.
+
+     This pass identifies fp32 biases and quantizes them using the scales of the input and weights.
+
+     This pass assumes that if a bias is fp32, the input and weights must already have been quantized.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph: torch.fx.Graph = graph_module.graph
+         for node in graph.nodes:
+             if node.op != "call_function":
+                 continue
+             if node.target == torch.ops.aten.linear.default:
+                 lin_args = LinearArgs(*node.args, **node.kwargs)
+                 inp = lin_args.input
+                 weights = lin_args.weight
+                 bias = lin_args.bias
+
+                 if bias is None:
+                     continue
+
+                 # Only support a bias that is a Parameter
+                 # TODO Is it possible that bias is not a Parameter?
+                 if not is_torch_param(bias, exported_program):
+                     continue
+
+                 bias_val: torch.Tensor = get_torch_param_value(bias, exported_program)
+                 if bias_val.dtype != torch.float32:
+                     continue
+
+                 if QPARAM_KEY not in inp.meta:
+                     continue
+
+                 if QPARAM_KEY not in weights.meta:
+                     continue
+
+                 quant_dtype = None
+                 if inp.meta[QPARAM_KEY].dtype == "int16":
+                     quant_dtype = torch.int64
+                 elif inp.meta[QPARAM_KEY].dtype == "uint8":
+                     quant_dtype = torch.int32
+                 else:
+                     continue
+
+                 type_info = torch.iinfo(quant_dtype)
+
+                 assert quant_dtype is not None
+
+                 i_scale = inp.meta[QPARAM_KEY].scale
+                 w_scale = weights.meta[QPARAM_KEY].scale
+
+                 assert i_scale is not None
+                 assert w_scale is not None
+                 assert len(i_scale) == 1
+                 assert len(w_scale) == bias_val.shape[0]
+
+                 bias_scale = torch.tensor(i_scale) * torch.tensor(w_scale)
+                 q_bias = torch.round(bias_val / bias_scale)
+                 q_bias = torch.clamp(q_bias, min=type_info.min, max=type_info.max)
+                 q_bias = q_bias.to(quant_dtype)
+
+                 q_bias_node = add_placeholder(exported_program, q_bias, bias.name)
+
+                 qparam = QuantParam()
+                 qparam.scale = bias_scale.tolist()
+                 assert qparam.scale is not None
+                 qparam.zero_point = [0] * len(qparam.scale)
+                 qparam.dtype = to_qparam_dtype(quant_dtype)
+                 qparam.quantized_dimension = 0
+                 q_bias_node.meta[QPARAM_KEY] = qparam
+
+                 node.update_arg(2, q_bias_node)
+
+                 logger.debug(f"Bias ({bias.name}) is quantized to {q_bias_node.name}.")
+
+             # TODO Support more ops.
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         # Run only once.
+         return PassResult(False)
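
Numerically, the pass above computes a per-channel bias scale as input_scale * weight_scale, then rounds and clamps into the accumulator dtype (int32 for uint8 activations, int64 for int16). A small worked example with made-up values:

    import torch

    i_scale = [0.02]                      # per-tensor input scale
    w_scale = [0.1, 0.05]                 # per-channel weight scales
    bias = torch.tensor([0.37, -0.12])    # fp32 bias, one value per channel

    bias_scale = torch.tensor(i_scale) * torch.tensor(w_scale)  # [0.0020, 0.0010]
    info = torch.iinfo(torch.int32)                             # uint8 activations -> int32 bias
    q_bias = torch.clamp(torch.round(bias / bias_scale), info.min, info.max).to(torch.int32)
    print(q_bias)  # tensor([ 185, -120], dtype=torch.int32)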
tico/experimental/quantization/passes/remove_weight_dequant_op.py
@@ -0,0 +1,177 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, Optional, TYPE_CHECKING, Union
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch._export.utils import (
+     get_buffer,
+     get_lifted_tensor_constant,
+     is_buffer,
+     is_lifted_tensor_constant,
+ )
+ from torch._subclasses.fake_tensor import FakeTensor
+ from torch.export import ExportedProgram
+
+ from tico.serialize.quant_param import QPARAM_KEY, QuantParam, to_qparam_dtype
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.validate_args_kwargs import (
+     DequantizePerChannelArgs,
+     DequantizePerTensorArgs,
+ )
+
+
+ def get_constant(exported_program: ExportedProgram, node: torch.fx.Node):
+     assert isinstance(node, torch.fx.Node)
+     if node.name in exported_program.constants:
+         return exported_program.constants[node.name]
+     elif is_buffer(exported_program, node):
+         return get_buffer(exported_program, node)
+     elif is_lifted_tensor_constant(exported_program, node):
+         return get_lifted_tensor_constant(exported_program, node)
+     else:
+         raise RuntimeError("NYI constant")
+
+
+ class ValRange:
+     def __init__(self, val: Union[torch.Tensor, List[int]]):
+         if isinstance(val, torch.Tensor):
+             self.max = torch.max(val).item()
+             self.min = torch.min(val).item()
+         elif type(val) == list:
+             self.max = max(val)
+             self.min = min(val)
+         else:
+             raise RuntimeError("Wrong dtype (val)")
+
+     def within(self, min_val, max_val):
+         return self.min >= min_val and self.max <= max_val
+
+
+ # Infer the qparam dtype using the weight, zero point, and dtype
+ def infer_dtype(weight: torch.Tensor, zerop: List[int], dtype: torch.dtype) -> str:
+     weight_val = ValRange(weight)
+     zp_val = ValRange(zerop)
+
+     if weight_val.within(0, 15) and zp_val.within(0, 15) and dtype == torch.uint8:
+         return "uint4"
+     else:
+         return to_qparam_dtype(dtype)
+
+
+ @trace_graph_diff_on_pass
+ class RemoveWeightDequantOp(PassBase):
+     """
+     This pass identifies and removes any remaining Dequantize ops associated with
+     quantized weights.
+
+     Since the weights were already quantized earlier (and possibly kept in float by
+     attaching a DQ), the final stage of the quantization pipeline typically
+     does not require those DQ ops anymore.
+
+     NOTE Removing 'DQ' causes a semantic change: f32 -> quantized
+
+     [BEFORE]
+     W (quantized) - Dequantize (float)
+
+     [AFTER]
+     W (quantized)
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph: torch.fx.Graph = graph_module.graph
+         for dq in graph.nodes:
+             if not dq.op == "call_function":
+                 continue
+
+             if dq.target not in [
+                 torch.ops.quantized_decomposed.dequantize_per_channel.default,
+                 torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+             ]:
+                 continue
+             dq_args: Optional[DequantizePerChannelArgs | DequantizePerTensorArgs] = None
+
+             if (
+                 dq.target
+                 == torch.ops.quantized_decomposed.dequantize_per_channel.default
+             ):
+                 dq_args = DequantizePerChannelArgs(*dq.args, **dq.kwargs)
+             elif (
+                 dq.target
+                 == torch.ops.quantized_decomposed.dequantize_per_tensor.default
+             ):
+                 dq_args = DequantizePerTensorArgs(*dq.args, **dq.kwargs)
+             else:
+                 raise RuntimeError(f"Invalid DQ target: {dq.target}")
+
+             q_weight = dq_args.input
+             # All weights are placeholders.
+             if q_weight.op != "placeholder":
+                 continue
+             # Check if the weight already has a quant param because a DQ can be shared.
+             if QPARAM_KEY in q_weight.meta:
+                 continue
+
+             q_weight_meta = q_weight.meta["val"]
+             assert isinstance(q_weight_meta, FakeTensor)
+             # Weight should have quantized values.
+             assert q_weight_meta.dtype != torch.float
+
+             q_weight_val = get_constant(exported_program, q_weight)
+             assert isinstance(q_weight_val, torch.Tensor)
+
+             quant_param = QuantParam()
+             if isinstance(dq_args, DequantizePerChannelArgs):
+                 scales = get_constant(exported_program, dq_args.scales)
+                 zero_ps = get_constant(exported_program, dq_args.zero_points)
+
+                 # Sometimes users can give an fp32 zero point. Let's update the dtype here.
+                 zero_ps = zero_ps.to(torch.int64)
+                 quant_param.scale = scales.tolist()
+                 quant_param.zero_point = zero_ps.tolist()
+                 assert quant_param.zero_point is not None  # To avoid mypy error
+                 quant_param.quantized_dimension = dq_args.axis
+                 quant_param.dtype = infer_dtype(
+                     q_weight_val, quant_param.zero_point, q_weight_meta.dtype
+                 )
+             elif isinstance(dq_args, DequantizePerTensorArgs):
+                 quant_param.scale = [dq_args.scale]
+                 quant_param.zero_point = [dq_args.zero_point]
+                 assert quant_param.zero_point is not None  # To avoid mypy error
+                 quant_param.dtype = infer_dtype(
+                     q_weight_val, quant_param.zero_point, q_weight_meta.dtype
+                 )
+             else:
+                 raise RuntimeError(f"Invalid DQ target: {dq.target}")
+
+             q_weight.meta[QPARAM_KEY] = quant_param
+             dq.replace_all_uses_with(q_weight, propagate_meta=False)
+             logger.debug(f"{dq.name} is removed.")
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         # Run only once.
+         return PassResult(False)
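
The `infer_dtype` heuristic above narrows a uint8-typed weight to "uint4" when every weight value and zero point fits in [0, 15]. A quick check that mirrors its logic (tensor values are made up):

    import torch

    weight = torch.tensor([[3, 15, 0], [7, 1, 9]], dtype=torch.uint8)
    zero_points = [8, 8]

    # Every value and zero point fits in 4 bits -> report "uint4"
    fits_uint4 = (
        int(weight.min()) >= 0 and int(weight.max()) <= 15
        and min(zero_points) >= 0 and max(zero_points) <= 15
    )
    print("uint4" if fits_uint4 else "uint8")  # uint4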
tico/experimental/quantization/public_interface.py
@@ -0,0 +1,108 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import copy
+ from typing import Any, Dict, Optional, Type
+
+ import torch
+
+ from tico.experimental.quantization.algorithm.gptq.quantizer import GPTQQuantizer
+ from tico.experimental.quantization.algorithm.pt2e.quantizer import PT2EQuantizer
+ from tico.experimental.quantization.algorithm.smoothquant.quantizer import (
+     SmoothQuantQuantizer,
+ )
+ from tico.experimental.quantization.config import BaseConfig
+ from tico.experimental.quantization.quantizer import BaseQuantizer
+
+
+ config_to_quantizer: Dict[str, Type[BaseQuantizer]] = {
+     "pt2e": PT2EQuantizer,
+     "gptq": GPTQQuantizer,
+     "smooth_quant": SmoothQuantQuantizer,
+ }
+
+ QUANTIZER_ATTRIBUTE_NAME = "tico_quantizer"
+
+
+ def prepare(
+     model: torch.nn.Module,
+     quant_config: BaseConfig,
+     args: Optional[Any] = None,
+     kwargs: Optional[Dict[str, Any]] = None,
+     inplace: Optional[bool] = False,
+ ):
+     """
+     Prepare the model for quantization using the provided configuration.
+
+     Determines the appropriate quantizer based on the type of `quant_config` and
+     prepares the model accordingly.
+
+     Parameters:
+         model: The PyTorch model to be quantized.
+         quant_config (BaseConfig): The quantization configuration.
+         args (Any, optional): Positional example inputs required for activation quantization.
+         kwargs (Dict[str, Any], optional): Keyword example inputs required for activation quantization.
+         inplace (bool, optional): If True, the model is modified in place;
+             otherwise, a new prepared model is returned.
+
+     Returns:
+         The model prepared for quantization.
+     """
+     if quant_config.name == "pt2e" and inplace:
+         raise RuntimeError(
+             "In-place is not supported for PT2E quantization due to limitations in the underlying Torch APIs. Please set 'inplace=False' to proceed."
+         )
+
+     model = model if inplace else copy.deepcopy(model)
+
+     quantizer = config_to_quantizer[quant_config.name](quant_config)
+     model = quantizer.prepare(model, args, kwargs)
+     setattr(model, QUANTIZER_ATTRIBUTE_NAME, quantizer)
+
+     return model
+
+
+ def convert(model, inplace: Optional[bool] = False):
+     """
+     Convert the prepared model to a quantized model.
+
+     Uses the quantizer that `prepare()` attached to the model and converts the
+     model accordingly.
+
+     Parameters:
+         model: The prepared PyTorch model.
+         inplace (bool, optional): If True, the model is modified in place;
+             otherwise, a new quantized model is returned.
+
+     Returns:
+         The quantized model.
+     """
+     # Get the quantizer first, before calling deepcopy, which does not copy attributes properly.
+     if hasattr(model, QUANTIZER_ATTRIBUTE_NAME):
+         quantizer = getattr(model, QUANTIZER_ATTRIBUTE_NAME)
+         delattr(model, QUANTIZER_ATTRIBUTE_NAME)
+     else:
+         raise RuntimeError("Call prepare() first.")
+
+     if isinstance(quantizer, PT2EQuantizer) and inplace:
+         raise RuntimeError(
+             "In-place is not supported for PT2E quantization due to limitations in the underlying Torch APIs. Please set 'inplace=False' to proceed."
+         )
+
+     model = model if inplace else copy.deepcopy(model)
+
+     assert isinstance(quantizer, BaseQuantizer)
+     model = quantizer.convert(model)
+     return model
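
End to end, the intended flow is prepare -> calibrate -> convert. A minimal sketch, assuming a toy module and placeholder calibration data; the concrete BaseConfig construction is omitted because it is not shown in this file:

    import torch
    from tico.experimental.quantization.public_interface import convert, prepare

    model = torch.nn.Linear(16, 8)            # any torch.nn.Module
    example_inputs = (torch.randn(1, 16),)

    # `quant_config` stands in for a BaseConfig instance whose `.name` is one of
    # "pt2e", "gptq", or "smooth_quant"; how it is built depends on the algorithm.
    prepared = prepare(model, quant_config, args=example_inputs)

    # ... feed representative inputs through `prepared` to calibrate ...

    quantized = convert(prepared)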