tico-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
tico/experimental/quantization/evaluation/utils.py
@@ -0,0 +1,185 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, List, Tuple
+
+ import numpy as np
+ import torch
+
+ from circle_schema import circle
+ from numpy.typing import DTypeLike
+
+ from tico.utils import logging
+ from tico.utils.model import CircleModel
+
+
+ def quantize(
+     data: np.ndarray, scale: float, zero_point: int, dtype: DTypeLike
+ ) -> np.ndarray:
+     """
+     Quantize the given data using the specified scale, zero point, and data type.
+     This function takes input data and applies quantization using the formula:
+         round(data / scale) + zero_point
+     The result is clamped to the range of the specified data type.
+     """
+     logger = logging.getLogger(__name__)
+     dtype = np.dtype(dtype)
+     assert dtype == np.uint8 or dtype == np.int16, f"Invalid dtype: {dtype}"
+     if dtype == np.int16:
+         assert zero_point == 0
+
+     # Convert input to NumPy array if necessary
+     if not isinstance(data, np.ndarray):
+         data = np.array(data)
+     # Perform quantization
+     if not scale:
+         logger.warning("WARNING: scale value is 0. 1e-7 will be used instead.")
+         scale = 1e-7
+     rescaled = np.round(data / scale) + zero_point
+     # Clamp the values
+     clipped = np.clip(rescaled, np.iinfo(dtype).min, np.iinfo(dtype).max)
+     # Convert to the specified dtype
+     return clipped.astype(dtype)
+
+
+ def dequantize(
+     data: np.ndarray, scale: float, zero_point: int, dtype: DTypeLike
+ ) -> np.ndarray:
+     """
+     Dequantize the given quantized data using the specified scale and zero point.
+     This function reverses the quantization process by applying the formula:
+         (quantized_value - zero_point) * scale
+     """
+     dtype = np.dtype(dtype)
+     assert dtype == np.uint8 or dtype == np.int16, f"Invalid dtype: {dtype}"
+     if dtype == np.int16:
+         assert zero_point == 0
+
+     # Convert input to NumPy array if necessary
+     if not isinstance(data, np.ndarray):
+         data = np.array(data)
+     # Perform dequantization
+     ret = (data.astype(np.float32) - zero_point) * scale
+     # np.float32 * np.int64 = np.float64
+     return ret.astype(np.float32)
+
+
+ def get_graph_input_output(
+     circle_model: CircleModel,
+ ) -> Tuple[List[circle.Tensor.Tensor], List[circle.Tensor.Tensor]]:
+     """
+     Retrieve the inputs and the outputs from the circle model, and return them
+     as two lists.
+     """
+     circle_buf: bytes = circle_model.circle_binary
+     circle_fb: circle.Model.Model = circle.Model.Model.GetRootAs(circle_buf, 0)
+     assert circle_fb.SubgraphsLength() == 1, "Only support single graph."
+     circle_graph = circle_fb.Subgraphs(0)
+     circle_inputs: List[circle.Tensor.Tensor] = [
+         circle_graph.Tensors(circle_graph.Inputs(i))
+         for i in range(circle_graph.InputsLength())
+     ]
+     circle_outputs: List[circle.Tensor.Tensor] = [
+         circle_graph.Tensors(circle_graph.Outputs(o))
+         for o in range(circle_graph.OutputsLength())
+     ]
+
+     return circle_inputs, circle_outputs
+
+
+ def find_invalid_types(
+     input: List[torch.Tensor] | List[np.ndarray], allowed_types: List
+ ) -> List:
+     """
+     Identifies the types of items in a list that are not allowed and removes duplicates.
+
+     Parameters
+     -----------
+     input
+         List of items to check.
+     allowed_types
+         List of allowed types (e.g. [int, str])
+     Returns
+     --------
+     A list of unique types that are not allowed in the input list.
+     """
+     # Use set comprehension for uniqueness
+     invalid_types = {
+         type(item) for item in input if not isinstance(item, tuple(allowed_types))
+     }
+     return list(invalid_types)
+
+
+ def plot_two_outputs(x_values: torch.Tensor, y_values: torch.Tensor):
+     """
+     Plot two values on a 2D graph using plotext.
+
+     Returns
+     --------
+     A figure built from plotext.
+
+     Example
+     --------
+     >>> x_values = torch.tensor([1, 2, 3, 4, 5])
+     >>> y_values = torch.tensor([10, 20, 30, 40, 50])
+     >>> fig = plot_two_outputs(x_values, y_values)
+     >>> print(fig)
+     """
+     x_np = x_values.numpy().reshape(-1)
+     y_np = y_values.numpy().reshape(-1)
+     min_value = min([x_np.min(), y_np.min()])
+     max_value = max([x_np.max(), y_np.max()])
+
+     interval = max_value - min_value
+     interval = 1.0 if interval == 0.0 else interval  # Avoid zero interval
+
+     # Enlarge axis
+     axis_min = min_value - interval * 0.05
+     axis_max = max_value + interval * 0.05
+
+     import plotext as plt
+
+     plt.clear_data()
+     plt.xlim(axis_min, axis_max)
+     plt.ylim(axis_min, axis_max)
+     plt.plotsize(width=50, height=25)
+     plt.scatter(x_np, y_np, marker="dot")
+     plt.theme("clear")
+
+     return plt.build()
+
+
+ def ensure_list(inputs: Any | Tuple[Any] | List[Any]) -> List[Any]:
+     """
+     Ensures that the given inputs are converted into a list.
+
+     - If the input is a single element, it wraps it in a list.
+     - If the input is a tuple, it converts the tuple to a list.
+     - If the input is already a list, it returns the input unchanged.
+
+     Example
+     --------
+     >>> ensure_list(42)
+     [42]
+     >>> ensure_list((1, 2, 3))
+     [1, 2, 3]
+     >>> ensure_list([4, 5, 6])
+     [4, 5, 6]
+     """
+     if isinstance(inputs, list):
+         return inputs
+     elif isinstance(inputs, tuple):
+         return list(inputs)
+     else:
+         return [inputs]
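Editor's note: the `quantize`/`dequantize` pair above implements per-tensor affine quantization. Below is a minimal round-trip sketch (not part of the package; it re-implements the two helpers with NumPy only) showing that dequantizing a quantized tensor with the same parameters recovers the input up to half a quantization step:

```python
# Standalone round-trip sketch of the helpers above (not an import from tico):
# quantizing and then dequantizing with the same per-tensor (scale, zero_point)
# reproduces the input up to scale / 2 of rounding error.
import numpy as np

def quantize(data, scale, zero_point, dtype):
    dtype = np.dtype(dtype)
    rescaled = np.round(np.asarray(data) / scale) + zero_point
    return np.clip(rescaled, np.iinfo(dtype).min, np.iinfo(dtype).max).astype(dtype)

def dequantize(data, scale, zero_point):
    return ((np.asarray(data).astype(np.float32) - zero_point) * scale).astype(np.float32)

x = np.array([-1.0, -0.5, 0.0, 0.5, 1.0], dtype=np.float32)
scale, zero_point = 2.0 / 255, 128            # uint8 parameters covering [-1, 1]
q = quantize(x, scale, zero_point, np.uint8)  # -> [0, 64, 128, 192, 255]
x_hat = dequantize(q, scale, zero_point)
assert np.max(np.abs(x_hat - x)) <= scale / 2 + 1e-6  # error <= half a step
```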
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
tico/experimental/quantization/passes/fold_quant_ops.py
@@ -0,0 +1,154 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import copy
+
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.serialize.quant_param import QPARAM_KEY, QuantParam
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.utils import get_quant_dtype
+ from tico.utils.validate_args_kwargs import (
+     DequantizePerTensorArgs,
+     QuantizePerTensorArgs,
+ )
+
+
+ @trace_graph_diff_on_pass
+ class FoldQuantOps(PassBase):
+     """
+     This pass folds the (Q - DQ) pattern into the preceding op. After quantization
+     in torch, activation ops follow an (op - Q - DQ) pattern.
+
+     To export a quantized circle model, this pass removes the (Q - DQ) nodes and
+     saves their quantization info into the preceding op's metadata.
+
+     ────────────────────────────────────────────────────────────────
+     BEFORE                           AFTER
+     ────────────────────────────────────────────────────────────────
+     op(float) ─ Q ─ DQ ─ …           op(float, meta[QPARAM])
+
+     op ─ Q1 ─ DQ1 ─ Q2 ─ DQ2         op(meta[QPARAM]) ─ Q2
+               ▲                                        ▲
+               │ (Q1, DQ1 folded)                       │ (re-quantization kept)
+
+     op ─ Q ─┬─ DQ0                   op(meta[QPARAM])
+             ├─ DQ1                   (each DQ* folded, Q dropped when orphaned)
+             └─ DQ2
+     ────────────────────────────────────────────────────────────────
+
+     Algorithm
+     ---------
+     1. Iterate over *all* Dequantize nodes.
+     2. For each DQ, verify it is driven by a Quantize node `q` and that
+        `q` and `dq` share identical (scale, zero-point, dtype).
+     3. a) If the producer op has **no** QPARAM, attach one, then replace
+           *this* DQ's usages with the producer op.
+        b) If the producer is already quantized with a different dtype,
+           this is a *re-quantization*: attach QPARAM to `q` and keep it,
+           but still remove the DQ.
+     4. After all replacements, run `graph.eliminate_dead_code()`.
+        Any Quantize that became orphaned because *all* its DQs were folded
+        is deleted automatically.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph: torch.fx.Graph = graph_module.graph
+         for dq in graph.nodes:
+             if dq.op != "call_function":
+                 continue
+             if (
+                 dq.target
+                 != torch.ops.quantized_decomposed.dequantize_per_tensor.default
+             ):
+                 continue
+             dq_args = DequantizePerTensorArgs(*dq.args, **dq.kwargs)
+
+             q = dq_args.input
+             if q.target != torch.ops.quantized_decomposed.quantize_per_tensor.default:
+                 continue
+             q_args = QuantizePerTensorArgs(*q.args, **q.kwargs)  # type: ignore[arg-type]
+             op = q_args.tensor
+
+             # Check if Q and DQ have the same quant param
+             if q_args.scale != dq_args.scale:
+                 continue
+             if q_args.zero_p != dq_args.zero_point:
+                 continue
+             if q_args.dtype != dq_args.dtype:
+                 continue
+
+             # ───────────────────────────────────────────
+             # Case 1: op not yet quantized
+             # ───────────────────────────────────────────
+             if QPARAM_KEY not in op.meta:
+                 qparam = QuantParam()
+                 qparam.scale = [q_args.scale]
+                 qparam.zero_point = [q_args.zero_p]
+                 qparam.dtype = get_quant_dtype(q_args.quant_min, q_args.quant_max)
+                 op.meta[QPARAM_KEY] = qparam
+
+                 dq.replace_all_uses_with(op, propagate_meta=False)
+
+                 logger.debug(f"{q.name} and {dq.name} are folded to {op.name}.")
+             # ───────────────────────────────────────────
+             # Case 2: op already quantized
+             #   2.1 same dtype → nothing to do
+             #   2.2 diff dtype → leave Q in place
+             # ───────────────────────────────────────────
+             else:
+                 op_qparam: QuantParam = op.meta[QPARAM_KEY]
+                 qdq_dtype = get_quant_dtype(q_args.quant_min, q_args.quant_max)
+
+                 if op_qparam.dtype != qdq_dtype:
+                     # Attach QPARAM to Q once
+                     if QPARAM_KEY not in q.meta:
+                         qparam = QuantParam()
+                         qparam.scale = [q_args.scale]
+                         qparam.zero_point = [q_args.zero_p]
+                         qparam.dtype = qdq_dtype
+                         q.meta[QPARAM_KEY] = qparam
+                         assert len(q.users) == 1, "Fix me unless"
+
+                     dq.replace_all_uses_with(q, propagate_meta=False)
+                     logger.debug(f"{dq.name} is folded ({q.name} is left).")
+                 else:
+                     # Same dtype → the Quantize–Dequantize pair is redundant.
+                     assert op_qparam.scale and op_qparam.scale[0] == q_args.scale
+                     assert (
+                         op_qparam.zero_point
+                         and op_qparam.zero_point[0] == q_args.zero_p
+                     )
+                     dq.replace_all_uses_with(op, propagate_meta=False)
+                     logger.debug(f"Removed redundant {dq.name}")
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         # Run only once.
+         return PassResult(False)
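Editor's note: a small standalone check (not from the package) of the precondition in step 2 of the algorithm above. Folding a Q-DQ pair is value-preserving only when Q and DQ share the same (scale, zero-point, dtype); with mismatched parameters the pair actually rescales the tensor, which is why the pass skips such pairs:

```python
# Standalone numeric check (assumes per-tensor affine quantization, as in
# torch.ops.quantized_decomposed.{quantize,dequantize}_per_tensor).
import numpy as np

def q(x, scale, zp):   # quantize to uint8
    return np.clip(np.round(x / scale) + zp, 0, 255).astype(np.uint8)

def dq(x, scale, zp):  # dequantize back to float
    return (x.astype(np.float32) - zp) * scale

x = np.linspace(-1, 1, 9, dtype=np.float32)
s1, z1 = 2.0 / 255, 128   # Q's parameters
s2, z2 = 4.0 / 255, 64    # a *different* parameter set

# Matching (scale, zp): DQ(Q(x)) is x plus bounded rounding error, so the
# pair carries no information beyond the qparams -> foldable into metadata.
matched = dq(q(x, s1, z1), s1, z1)
assert np.allclose(matched, x, atol=s1 / 2 + 1e-6)

# Mismatched parameters: the pair rescales the tensor -> must not be folded.
mismatched = dq(q(x, s1, z1), s2, z2)
assert not np.allclose(mismatched, x, atol=s1 / 2 + 1e-6)
```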
tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py
@@ -0,0 +1,345 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import copy
+
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.serialize.quant_param import QPARAM_KEY, QuantParam
+ from tico.utils import logging
+ from tico.utils.errors import NotYetSupportedError
+ from tico.utils.graph import create_node
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.utils import quant_min_max, set_new_meta_val
+ from tico.utils.validate_args_kwargs import (
+     BmmArgs,
+     LinearArgs,
+     MulTensorArgs,
+     PermuteArgs,
+     ReshapeArgs,
+ )
+
+
+ def qparam_dtype(node: torch.fx.Node) -> str:
+     assert QPARAM_KEY in node.meta
+     return node.meta[QPARAM_KEY].dtype
+
+
+ # Convert i16 qparam to u8 qparam
+ # scale and zero_point are inferred from the i16 qparam
+ def _i16_to_u8(qparam: QuantParam) -> QuantParam:
+     # Assume per-tensor quantization
+     assert qparam.scale is not None and len(qparam.scale) == 1
+     assert qparam.dtype == "int16"
+
+     s16_scale = qparam.scale[0]
+     max_ = s16_scale * 32767  # numeric_limits<int16>
+     min_ = -max_
+
+     u8_scale = (max_ - min_) / 255
+     u8_zerop = round(-min_ / u8_scale)
+
+     new_qparam = QuantParam()
+     new_qparam.scale = [u8_scale]
+     new_qparam.zero_point = [u8_zerop]
+     new_qparam.dtype = "uint8"
+
+     return new_qparam
+
+
+ # Convert u8 qparam to i16 qparam
+ # scale is inferred from the u8 qparam
+ def _u8_to_i16(qparam: QuantParam) -> QuantParam:
+     # Assume per-tensor quantization
+     assert qparam.scale is not None and len(qparam.scale) == 1
+     assert qparam.zero_point is not None and len(qparam.zero_point) == 1
+     assert qparam.dtype == "uint8"
+
+     u8_scale = qparam.scale[0]
+     u8_zerop = qparam.zero_point[0]
+     max_ = u8_scale * (255 - u8_zerop)
+     min_ = u8_scale * (-u8_zerop)
+
+     abs_max = max([max_, min_], key=abs)
+     s16_scale = abs_max / 32767
+     s16_zerop = 0
+
+     new_qparam = QuantParam()
+     new_qparam.scale = [s16_scale]
+     new_qparam.zero_point = [s16_zerop]
+     new_qparam.dtype = "int16"
+
+     return new_qparam
+
+
+ @trace_graph_diff_on_pass
+ class InsertQuantizeOnDtypeMismatch(PassBase):
+     """
+     Insert a Quantize op where circle's type inference would otherwise be violated.
+
+     Example: FullyConnected
+     [BEFORE]
+       Op (uint8) - aten.linear.default (int16)
+     [AFTER]
+       Op (uint8) - aten.linear.default (uint8) - quantized_decomposed.quantize_per_tensor.default (int16)
+
+     Why is this pass necessary?
+     - For some operators, circle's type inference pass overwrites the output's dtype
+       with the input's dtype. In the above example, the fully-connected layer's
+       (aten.linear.default) output dtype (int16) is updated to the input dtype (uint8),
+       which breaks the semantics. This problem can occur in tools (e.g., circle2circle)
+       that automatically apply type inference.
+     - To resolve the issue, we insert Quantize operators so that circle's type
+       inference logic is not violated.
+     - NOTE In some cases, the Quantize op is inserted *before* the operator.
+
+     Let's assume a Reshape op whose input is int16 and whose output is uint8. There are
+     two possible places to insert the Quantize op.
+
+     1. Insert Quantize before Reshape.
+
+     ```
+     Predecessor (int16) -> Quantize (uint8) -> Reshape (uint8) -> ...
+     ```
+
+     2. Insert Quantize after Reshape.
+
+     ```
+     Predecessor (int16) -> Reshape (int16) -> Quantize (uint8) -> ...
+     ```
+
+     Comparing 1) and 2), the difference is whether the Reshape operation is conducted
+     in uint8 or int16. We go with 1), which does Reshape in uint8, for faster execution.
+     Note that Reshape does not change the values, so its dtype does not affect accuracy.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph: torch.fx.Graph = graph_module.graph
+
+         def _insert_quantize_op_before(node, inp):
+             qparam: QuantParam = node.meta[QPARAM_KEY]
+             assert qparam.scale is not None
+             assert qparam.zero_point is not None
+             scale = qparam.scale[0]
+             zerop = qparam.zero_point[0]
+             min_, max_ = quant_min_max(qparam.dtype)
+             dtype = getattr(torch, qparam.dtype)
+
+             with graph.inserting_before(node):
+                 q_args = (inp, scale, zerop, min_, max_, dtype)
+                 quantize = create_node(
+                     graph,
+                     torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                     args=q_args,
+                     origin=node,
+                 )
+                 quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
+                 set_new_meta_val(quantize)
+
+             node.replace_input_with(inp, quantize)
+
+             return quantize
+
+         def _insert_quantize_op_after(node):
+             qparam: QuantParam = node.meta[QPARAM_KEY]
+             assert qparam.scale is not None
+             assert qparam.zero_point is not None
+             scale = qparam.scale[0]
+             zerop = qparam.zero_point[0]
+             min_, max_ = quant_min_max(qparam.dtype)
+             dtype = getattr(torch, qparam.dtype)
+             with graph.inserting_after(node):
+                 q_args = (node, scale, zerop, min_, max_, dtype)
+                 quantize = create_node(
+                     graph,
+                     torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                     args=q_args,
+                 )
+
+             node.replace_all_uses_with(quantize, propagate_meta=True)
+             # replace_all_uses_with above also rewired quantize's own input;
+             # point it back to node.
+             quantize.replace_input_with(quantize, node)
+
+             quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
+
+             return quantize
+
+         for node in graph.nodes:
+             if node.op != "call_function":
+                 continue
+             if node.target == torch.ops.aten.linear.default:
+                 lin_args = LinearArgs(*node.args, **node.kwargs)
+                 inp = lin_args.input
+
+                 if QPARAM_KEY not in inp.meta:
+                     continue
+
+                 if QPARAM_KEY not in node.meta:
+                     continue
+
+                 if qparam_dtype(inp) == qparam_dtype(node):
+                     continue
+
+                 if qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+                     quantize = _insert_quantize_op_after(node)
+
+                     quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+
+                     # Update node's qparam from i16 to u8
+                     # NOTE This can severely degrade accuracy. It is
+                     # important to mitigate this accuracy drop in the backend.
+                     node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+                     logger.debug(
+                         f"quantize_per_tensor.default is inserted after {node.name}."
+                     )
+                 else:
+                     raise NotYetSupportedError("Unsupported dtype")
+
+             elif node.target == torch.ops.aten.mul.Tensor:
+                 mul_args = MulTensorArgs(*node.args, **node.kwargs)
+                 x = mul_args.input
+                 y = mul_args.other
+
+                 if not isinstance(x, torch.fx.Node):
+                     continue
+                 if not isinstance(y, torch.fx.Node):
+                     continue
+
+                 if QPARAM_KEY not in x.meta:
+                     continue
+                 if QPARAM_KEY not in y.meta:
+                     continue
+                 if QPARAM_KEY not in node.meta:
+                     continue
+
+                 if qparam_dtype(x) == qparam_dtype(node):
+                     continue
+
+                 if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+                     quantize = _insert_quantize_op_after(node)
+
+                     quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                     node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+                     logger.debug(
+                         f"quantize_per_tensor.default is inserted after {node.name}."
+                     )
+                 else:
+                     raise NotYetSupportedError("Unsupported dtype")
+
+             elif node.target == torch.ops.aten.bmm.default:
+                 bmm_args = BmmArgs(*node.args, **node.kwargs)
+                 x = bmm_args.input
+                 y = bmm_args.mat2
+
+                 if QPARAM_KEY not in x.meta:
+                     continue
+                 if QPARAM_KEY not in y.meta:
+                     continue
+                 if QPARAM_KEY not in node.meta:
+                     continue
+
+                 if qparam_dtype(x) == qparam_dtype(node):
+                     continue
+
+                 if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+                     quantize = _insert_quantize_op_after(node)
+
+                     quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                     node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+                     logger.debug(
+                         f"quantize_per_tensor.default is inserted after {node.name}."
+                     )
+                 else:
+                     raise NotYetSupportedError("Unsupported dtype")
+
+             elif node.target == torch.ops.aten.permute.default:
+                 per_args = PermuteArgs(*node.args, **node.kwargs)
+                 inp = per_args.input
+
+                 if QPARAM_KEY not in inp.meta:
+                     continue
+
+                 if QPARAM_KEY not in node.meta:
+                     continue
+
+                 if qparam_dtype(inp) == qparam_dtype(node):
+                     continue
+
+                 if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
+                     # A new Quantize op (s16 to u8) is inserted before (not after)
+                     # the permute op to reduce tensor size earlier
+                     quantize = _insert_quantize_op_before(node, inp)
+
+                     quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                     logger.debug(
+                         f"quantize_per_tensor.default is inserted before {node.name}."
+                     )
+                 elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+                     quantize = _insert_quantize_op_after(node)
+
+                     quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                     node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+                     logger.debug(
+                         f"quantize_per_tensor.default is inserted after {node.name}."
+                     )
+                 else:
+                     raise NotYetSupportedError("Unsupported dtype")
+             elif node.target == torch.ops.aten.reshape.default:
+                 reshape_args = ReshapeArgs(*node.args, **node.kwargs)
+                 inp = reshape_args.input
+
+                 if QPARAM_KEY not in inp.meta:
+                     continue
+
+                 if QPARAM_KEY not in node.meta:
+                     continue
+
+                 if qparam_dtype(inp) == qparam_dtype(node):
+                     continue
+
+                 if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
+                     # A new Quantize op (s16 to u8) is inserted before (not after)
+                     # the reshape op to reduce tensor size earlier
+                     quantize = _insert_quantize_op_before(node, inp)
+
+                     quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                     logger.debug(
+                         f"quantize_per_tensor.default is inserted before {node.name}."
+                     )
+                 elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+                     quantize = _insert_quantize_op_after(node)
+
+                     quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                     node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+                     logger.debug(
+                         f"quantize_per_tensor.default is inserted after {node.name}."
+                     )
+                 else:
+                     raise NotYetSupportedError("Unsupported dtype")
+
+             # TODO Support more ops.
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         # Run only once.
+         return PassResult(False)
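Editor's note: to make the i16/u8 qparam conversions above concrete, here is a worked standalone check (not part of the package) mirroring `_i16_to_u8` and `_u8_to_i16`. Both directions preserve (approximately) the representable float range, which is what lets the pass swap a node's QPARAM dtype and compensate with an inserted Quantize op:

```python
# Mirrors _i16_to_u8 / _u8_to_i16 from the pass above (standalone sketch).
def i16_to_u8(scale):                  # int16 is symmetric: zero_point == 0
    max_ = scale * 32767
    min_ = -max_
    u8_scale = (max_ - min_) / 255
    u8_zerop = round(-min_ / u8_scale)
    return u8_scale, u8_zerop

def u8_to_i16(scale, zerop):
    max_ = scale * (255 - zerop)
    min_ = scale * (-zerop)
    abs_max = max([max_, min_], key=abs)
    return abs_max / 32767, 0

s16_scale = 0.001                      # int16 covers ~[-32.767, 32.767]
u8_scale, u8_zerop = i16_to_u8(s16_scale)
print(u8_scale, u8_zerop)              # ~0.257, zero-point lands mid-range (~128)

# Round-tripping recovers the original int16 scale to within ~0.4 %
# (the small drift comes from rounding the zero-point to an integer).
s16_back, z = u8_to_i16(u8_scale, u8_zerop)
assert z == 0 and abs(s16_back - s16_scale) / s16_scale < 0.01
```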