tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,185 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Tuple
+
+import numpy as np
+import torch
+
+from circle_schema import circle
+from numpy.typing import DTypeLike
+
+from tico.utils import logging
+from tico.utils.model import CircleModel
+
+
+def quantize(
+    data: np.ndarray, scale: float, zero_point: int, dtype: DTypeLike
+) -> np.ndarray:
+    """
+    Quantize the given data using the specified scale, zero point, and data type.
+    This function takes input data and applies quantization using the formula:
+        round(data / scale) + zero_point
+    The result is clamped to the range of the specified data type.
+    """
+    logger = logging.getLogger(__name__)
+    dtype = np.dtype(dtype)
+    assert dtype == np.uint8 or dtype == np.int16, f"Invalid dtype: {dtype}"
+    if dtype == np.int16:
+        assert zero_point == 0
+
+    # Convert input to a NumPy array if necessary
+    if not isinstance(data, np.ndarray):
+        data = np.array(data)
+    # Perform quantization
+    if not scale:
+        logger.warning("WARNING: scale value is 0. 1e-7 will be used instead.")
+        scale = 1e-7
+    rescaled = np.round(data / scale) + zero_point
+    # Clamp the values
+    clipped = np.clip(rescaled, np.iinfo(dtype).min, np.iinfo(dtype).max)
+    # Convert to the specified dtype
+    return clipped.astype(dtype)
+
+
+def dequantize(
+    data: np.ndarray, scale: float, zero_point: int, dtype: DTypeLike
+) -> np.ndarray:
+    """
+    Dequantize the given quantized data using the specified scale and zero point.
+    This function reverses the quantization process by applying the formula:
+        (quantized_value - zero_point) * scale
+    """
+    dtype = np.dtype(dtype)
+    assert dtype == np.uint8 or dtype == np.int16, f"Invalid dtype: {dtype}"
+    if dtype == np.int16:
+        assert zero_point == 0
+
+    # Convert input to a NumPy array if necessary
+    if not isinstance(data, np.ndarray):
+        data = np.array(data)
+    # Perform dequantization
+    ret = (data.astype(np.float32) - zero_point) * scale
+    # np.float32 * np.int64 = np.float64
+    return ret.astype(np.float32)
+
+
+def get_graph_input_output(
+    circle_model: CircleModel,
+) -> Tuple[List[circle.Tensor.Tensor], List[circle.Tensor.Tensor]]:
+    """
+    Retrieve the inputs and the outputs from the circle model, and return them
+    as two lists.
+    """
+    circle_buf: bytes = circle_model.circle_binary
+    circle_fb: circle.Model.Model = circle.Model.Model.GetRootAs(circle_buf, 0)
+    assert circle_fb.SubgraphsLength() == 1, "Only a single graph is supported."
+    circle_graph = circle_fb.Subgraphs(0)
+    circle_inputs: List[circle.Tensor.Tensor] = [
+        circle_graph.Tensors(circle_graph.Inputs(i))
+        for i in range(circle_graph.InputsLength())
+    ]
+    circle_outputs: List[circle.Tensor.Tensor] = [
+        circle_graph.Tensors(circle_graph.Outputs(o))
+        for o in range(circle_graph.OutputsLength())
+    ]
+
+    return circle_inputs, circle_outputs
+
+
+def find_invalid_types(
+    input: List[torch.Tensor] | List[np.ndarray], allowed_types: List
+) -> List:
+    """
+    Identifies the types of items in a list that are not allowed and removes duplicates.
+
+    Parameters
+    -----------
+    input
+        List of items to check.
+    allowed_types
+        List of allowed types (e.g. [int, str])
+    Returns
+    --------
+    A list of unique types that are not allowed in the input list.
+    """
+    # Use a set comprehension for uniqueness
+    invalid_types = {
+        type(item) for item in input if not isinstance(item, tuple(allowed_types))
+    }
+    return list(invalid_types)
+
+
+def plot_two_outputs(x_values: torch.Tensor, y_values: torch.Tensor):
+    """
+    Plot two values on a 2D graph using plotext.
+
+    Returns
+    --------
+    A figure built from plotext.
+
+    Example
+    --------
+    >>> x_values = torch.tensor([1, 2, 3, 4, 5])
+    >>> y_values = torch.tensor([10, 20, 30, 40, 50])
+    >>> fig = plot_two_outputs(x_values, y_values)
+    >>> print(fig)
+    """
+    x_np = x_values.numpy().reshape(-1)
+    y_np = y_values.numpy().reshape(-1)
+    min_value = min([x_np.min(), y_np.min()])
+    max_value = max([x_np.max(), y_np.max()])
+
+    interval = max_value - min_value
+    interval = 1.0 if interval == 0.0 else interval  # Avoid zero interval
+
+    # Enlarge the axes slightly
+    axis_min = min_value - interval * 0.05
+    axis_max = max_value + interval * 0.05
+
+    import plotext as plt
+
+    plt.clear_data()
+    plt.xlim(axis_min, axis_max)
+    plt.ylim(axis_min, axis_max)
+    plt.plotsize(width=50, height=25)
+    plt.scatter(x_np, y_np, marker="dot")
+    plt.theme("clear")
+
+    return plt.build()
+
+
+def ensure_list(inputs: Any | Tuple[Any] | List[Any]) -> List[Any]:
+    """
+    Ensures that the given input is converted into a list.
+
+    - If the input is a single element, it wraps it in a list.
+    - If the input is a tuple, it converts the tuple to a list.
+    - If the input is already a list, it returns the input unchanged.
+
+    Example
+    --------
+    >>> ensure_list(42)
+    [42]
+    >>> ensure_list((1, 2, 3))
+    [1, 2, 3]
+    >>> ensure_list([4, 5, 6])
+    [4, 5, 6]
+    """
+    if isinstance(inputs, list):
+        return inputs
+    elif isinstance(inputs, tuple):
+        return list(inputs)
+    else:
+        return [inputs]
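
A quick round trip through the `quantize`/`dequantize` helpers above illustrates the affine scheme they implement (a minimal sketch; assumes this wheel is installed, and the sample values and qparams are illustrative):

```python
import numpy as np

from tico.experimental.quantization.evaluation.utils import dequantize, quantize

data = np.array([-0.5, 0.0, 0.25, 1.0], dtype=np.float32)
scale, zero_point = 0.02, 128  # hypothetical uint8 qparams

q = quantize(data, scale, zero_point, np.uint8)   # round(x / 0.02) + 128
deq = dequantize(q, scale, zero_point, np.uint8)  # (q - 128) * 0.02

print(q)                         # [103 128 140 178]
print(np.abs(deq - data).max())  # reconstruction error is at most scale / 2
```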
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
@@ -0,0 +1,97 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.quant_param import QPARAM_KEY, QuantParam, to_qparam_dtype
+from tico.utils import logging
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import (
+    DequantizePerTensorArgs,
+    QuantizePerTensorArgs,
+)
+
+
+@trace_graph_diff_on_pass
+class FoldQuantOps(PassBase):
+    """
+    This pass folds the (Q - DQ) pattern into the previous op. After quantization in
+    torch, activation ops have an (op - Q - DQ) pattern.
+
+    To export a quantized circle model, this pass removes the (Q - DQ) nodes and saves
+    their quantization info to the previous op's metadata.
+
+    [BEFORE]
+      op (float) - Quantize - Dequantize - (float)
+
+    [AFTER]
+      op (float with meta[QPARAM_KEY])
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+        for dq in graph.nodes:
+            if dq.op != "call_function":
+                continue
+            if (
+                dq.target
+                != torch.ops.quantized_decomposed.dequantize_per_tensor.default
+            ):
+                continue
+            dq_args = DequantizePerTensorArgs(*dq.args, **dq.kwargs)
+
+            q = dq_args.input
+            if q.target != torch.ops.quantized_decomposed.quantize_per_tensor.default:
+                continue
+            q_args = QuantizePerTensorArgs(*q.args, **q.kwargs)  # type: ignore[arg-type]
+            op = q_args.tensor
+
+            # Check if Q and DQ have the same quant params
+            if q_args.scale != dq_args.scale:
+                continue
+            if q_args.zero_p != dq_args.zero_point:
+                continue
+            if q_args.dtype != dq_args.dtype:
+                continue
+
+            if QPARAM_KEY not in op.meta:
+                qparam = QuantParam()
+                qparam.scale = [q_args.scale]
+                qparam.zero_point = [q_args.zero_p]
+                assert "val" in q.meta and hasattr(q.meta["val"], "dtype")
+                qparam.dtype = to_qparam_dtype(q.meta["val"].dtype)
+                op.meta[QPARAM_KEY] = qparam
+
+            dq.replace_all_uses_with(op, propagate_meta=False)
+
+            logger.debug(f"{q.name} and {dq.name} are folded to {op.name}.")
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        # Run only once.
+        return PassResult(False)
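
For context, the (op - Q - DQ) shape this pass matches can be reproduced with a toy export (a hedged sketch; the module and qparams are illustrative, and the `_decomposed` import is what registers the `quantized_decomposed` ops in recent torch versions):

```python
import torch
import torch.ao.quantization.fx._decomposed  # registers quantized_decomposed ops


class ActQDQ(torch.nn.Module):
    def forward(self, x):
        y = x + 1.0  # the float "op" whose output will carry the qparams
        q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            y, 0.02, 128, 0, 255, torch.uint8
        )
        # Same scale/zero_point/dtype as the Quantize above, so the pair is foldable
        return torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            q, 0.02, 128, 0, 255, torch.uint8
        )


ep = torch.export.export(ActQDQ(), (torch.randn(4),))
print(ep.graph)  # add -> quantize_per_tensor -> dequantize_per_tensor
# Running FoldQuantOps on `ep` should drop the Q/DQ pair and leave the add node
# carrying meta[QPARAM_KEY] with scale [0.02], zero_point [128], dtype "uint8".
```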
@@ -0,0 +1,289 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import copy
+
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.quant_param import QPARAM_KEY, QuantParam
+from tico.utils import logging
+from tico.utils.errors import NotYetSupportedError
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.utils import quant_min_max, set_new_meta_val
+from tico.utils.validate_args_kwargs import (
+    BmmArgs,
+    LinearArgs,
+    MulTensorArgs,
+    PermuteArgs,
+)
+
+
+def qparam_dtype(node: torch.fx.Node) -> str:
+    assert QPARAM_KEY in node.meta
+    return node.meta[QPARAM_KEY].dtype
+
+
+# Convert an i16 qparam to a u8 qparam
+# scale and zero_point are inferred from the i16 qparam
+def _i16_to_u8(qparam: QuantParam) -> QuantParam:
+    # Assume per-tensor quantization
+    assert qparam.scale is not None and len(qparam.scale) == 1
+    assert qparam.dtype == "int16"
+
+    s16_scale = qparam.scale[0]
+    max_ = s16_scale * 32767  # numeric_limits<int16>
+    min_ = -max_
+
+    u8_scale = (max_ - min_) / 255
+    u8_zerop = round(-min_ / u8_scale)
+
+    new_qparam = QuantParam()
+    new_qparam.scale = [u8_scale]
+    new_qparam.zero_point = [u8_zerop]
+    new_qparam.dtype = "uint8"
+
+    return new_qparam
+
+
+# Convert a u8 qparam to an i16 qparam
+# scale is inferred from the u8 qparam
+def _u8_to_i16(qparam: QuantParam) -> QuantParam:
+    # Assume per-tensor quantization
+    assert qparam.scale is not None and len(qparam.scale) == 1
+    assert qparam.zero_point is not None and len(qparam.zero_point) == 1
+    assert qparam.dtype == "uint8"
+
+    u8_scale = qparam.scale[0]
+    u8_zerop = qparam.zero_point[0]
+    max_ = u8_scale * (255 - u8_zerop)
+    min_ = u8_scale * (-u8_zerop)
+
+    abs_max = max([max_, min_], key=abs)
+    s16_scale = abs_max / 32767
+    s16_zerop = 0
+
+    new_qparam = QuantParam()
+    new_qparam.scale = [s16_scale]
+    new_qparam.zero_point = [s16_zerop]
+    new_qparam.dtype = "int16"
+
+    return new_qparam
+
+
+@trace_graph_diff_on_pass
+class InsertQuantizeOnDtypeMismatch(PassBase):
+    """
+    Insert a Quantize Op in operators where circle's type inference would be violated.
+    Example. FullyConnected
+    [BEFORE]
+      Op (uint8) - aten.linear.default (int16)
+    [AFTER]
+      Op (uint8) - aten.linear.default (uint8) - quantized_decomposed.quantize_per_tensor.default (int16)
+    Why is this pass necessary?
+    - For some operators, circle's type inference pass overwrites the input's dtype with
+      the output's dtype. In the above example, the fully-connected layer's (aten.linear.default)
+      output dtype (int16) would be updated to the input dtype (uint8), which breaks the semantics.
+      This problem can occur in tools (ex: circle2circle) that automatically apply type inference.
+    - To resolve the issue, we insert Quantize operators so as not to violate circle's type inference logic.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+
+        def _insert_quantize_op_before(node, inp):
+            qparam: QuantParam = node.meta[QPARAM_KEY]
+            assert qparam.scale is not None
+            assert qparam.zero_point is not None
+            scale = qparam.scale[0]
+            zerop = qparam.zero_point[0]
+            min_, max_ = quant_min_max(qparam.dtype)
+            dtype = getattr(torch, qparam.dtype)
+
+            with graph.inserting_before(node):
+                q_args = (inp, scale, zerop, min_, max_, dtype)
+                quantize = graph.call_function(
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    args=q_args,
+                )
+                quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
+                set_new_meta_val(quantize)
+
+            node.replace_input_with(inp, quantize)
+
+            return quantize
+
+        def _insert_quantize_op_after(node):
+            qparam: QuantParam = node.meta[QPARAM_KEY]
+            assert qparam.scale is not None
+            assert qparam.zero_point is not None
+            scale = qparam.scale[0]
+            zerop = qparam.zero_point[0]
+            min_, max_ = quant_min_max(qparam.dtype)
+            dtype = getattr(torch, qparam.dtype)
+            with graph.inserting_after(node):
+                q_args = (node, scale, zerop, min_, max_, dtype)
+                quantize = graph.call_function(
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    args=q_args,
+                )
+
+            node.replace_all_uses_with(quantize, propagate_meta=True)
+            quantize.replace_input_with(quantize, node)  # undo the self-loop created above
+
+            quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
+
+            return quantize
+
+        for node in graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target == torch.ops.aten.linear.default:
+                lin_args = LinearArgs(*node.args, **node.kwargs)
+                inp = lin_args.input
+
+                if QPARAM_KEY not in inp.meta:
+                    continue
+
+                if QPARAM_KEY not in node.meta:
+                    continue
+
+                if qparam_dtype(inp) == qparam_dtype(node):
+                    continue
+
+                if qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+                    quantize = _insert_quantize_op_after(node)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+
+                    # Update node's qparam from i16 to u8
+                    # NOTE This would severely degrade accuracy. It is
+                    # important to mitigate this accuracy drop in the backend.
+                    node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted after {node.name}."
+                    )
+                else:
+                    raise NotYetSupportedError("Unsupported dtype")
+
+            elif node.target == torch.ops.aten.mul.Tensor:
+                mul_args = MulTensorArgs(*node.args, **node.kwargs)
+                x = mul_args.input
+                y = mul_args.other
+
+                if not isinstance(x, torch.fx.Node):
+                    continue
+                if not isinstance(y, torch.fx.Node):
+                    continue
+
+                if QPARAM_KEY not in x.meta:
+                    continue
+                if QPARAM_KEY not in y.meta:
+                    continue
+                if QPARAM_KEY not in node.meta:
+                    continue
+
+                if qparam_dtype(x) == qparam_dtype(node):
+                    continue
+
+                if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+                    quantize = _insert_quantize_op_after(node)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted after {node.name}."
+                    )
+                else:
+                    raise NotYetSupportedError("Unsupported dtype")
+
+            elif node.target == torch.ops.aten.bmm.default:
+                bmm_args = BmmArgs(*node.args, **node.kwargs)
+                x = bmm_args.input
+                y = bmm_args.mat2
+
+                if QPARAM_KEY not in x.meta:
+                    continue
+                if QPARAM_KEY not in y.meta:
+                    continue
+                if QPARAM_KEY not in node.meta:
+                    continue
+
+                if qparam_dtype(x) == qparam_dtype(node):
+                    continue
+
+                if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+                    quantize = _insert_quantize_op_after(node)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted after {node.name}."
+                    )
+                else:
+                    raise NotYetSupportedError("Unsupported dtype")
+
+            elif node.target == torch.ops.aten.permute.default:
+                per_args = PermuteArgs(*node.args, **node.kwargs)
+                inp = per_args.input
+
+                if QPARAM_KEY not in inp.meta:
+                    continue
+
+                if QPARAM_KEY not in node.meta:
+                    continue
+
+                if qparam_dtype(inp) == qparam_dtype(node):
+                    continue
+
+                if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
+                    # A new Quantize Op (s16 to u8) is inserted before (not after)
+                    # the permute Op to reduce tensor size earlier
+                    quantize = _insert_quantize_op_before(node, inp)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted before {node.name}."
+                    )
+                elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+                    quantize = _insert_quantize_op_after(node)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted after {node.name}."
+                    )
+                else:
+                    raise NotYetSupportedError("Unsupported dtype")
+
+            # TODO Support more ops.
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        # Run only once.
+        return PassResult(False)
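
The `_i16_to_u8` rewrite above admits a quick arithmetic check: an int16 qparam is symmetric around zero, so `-min_ / u8_scale` is algebraically 255/2 = 127.5 regardless of the scale, and the derived uint8 zero point lands at its rounding (worked through here with a hypothetical scale):

```python
s16_scale = 0.001
max_ = s16_scale * 32767        # 32.767, the largest representable magnitude
min_ = -max_                    # int16 qparams are symmetric around zero

u8_scale = (max_ - min_) / 255  # ~0.25699, full range spread over 255 steps
u8_zerop = round(-min_ / u8_scale)  # algebraically 255/2 = 127.5 for any s16_scale

print(u8_scale, u8_zerop)
```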
@@ -0,0 +1,91 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import copy
+
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.quant_param import QPARAM_KEY
+from tico.utils import logging
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import CatArgs, PermuteArgs, ReshapeArgs
+
+
+@trace_graph_diff_on_pass
+class PropagateQParamBackward(PassBase):
+    """
+    This pass propagates quantization parameters backward.
+
+    BEFORE)
+
+      node -> reshape (with meta[QPARAM_KEY])
+
+    AFTER)
+
+      node (with meta[QPARAM_KEY]) -> reshape (with meta[QPARAM_KEY])
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+
+        def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node):
+            if QPARAM_KEY not in src.meta:
+                return
+
+            if (
+                QPARAM_KEY in dst.meta
+                and src.meta[QPARAM_KEY].dtype != dst.meta[QPARAM_KEY].dtype
+            ):
+                return
+
+            dst.meta[QPARAM_KEY] = copy.deepcopy(src.meta[QPARAM_KEY])
+
+            logger.debug(f"{src.name}'s quantparam is propagated to {dst.name}.")
+
+        # Do reverse-order traversal for backward propagation
+        for node in reversed(graph.nodes):
+            if node.op != "call_function":
+                continue
+            if node.target == torch.ops.aten.cat.default:
+                concat_args = CatArgs(*node.args, **node.kwargs)
+                concat_inputs = concat_args.tensors
+
+                for concat_input in concat_inputs:
+                    _propagate_qparam_if_possible(node, concat_input)
+            elif node.target == torch.ops.aten.reshape.default:
+                args = ReshapeArgs(*node.args, **node.kwargs)
+                _propagate_qparam_if_possible(node, args.input)
+            elif node.target == torch.ops.aten.permute.default:
+                permute_args = PermuteArgs(*node.args, **node.kwargs)
+                _propagate_qparam_if_possible(node, permute_args.input)
+            # TODO Support more ops.
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        # Run only once.
+        return PassResult(False)
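
A hedged sketch of what the backward propagation rule does for a lone reshape, applied by hand on a toy fx graph (assumes this wheel is installed; the qparam values are made up):

```python
import copy

import torch
import torch.fx

from tico.serialize.quant_param import QPARAM_KEY, QuantParam

g = torch.fx.Graph()
x = g.placeholder("x")
r = g.call_function(torch.ops.aten.reshape.default, (x, [2, 8]))
g.output(r)

qparam = QuantParam()
qparam.scale, qparam.zero_point, qparam.dtype = [0.02], [128], "uint8"
r.meta[QPARAM_KEY] = qparam  # only the reshape output is annotated

# The pass's rule for reshape: copy the output qparam onto the input,
# since reshape does not change the tensor's value distribution.
x.meta[QPARAM_KEY] = copy.deepcopy(r.meta[QPARAM_KEY])
```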