tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/passes/convert_to_relu6.py
@@ -0,0 +1,180 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.validate_args_kwargs import ClampArgs, HardTanhArgs
+
+
+ class Converter:  # type: ignore[empty-body]
+     def __init__(self):
+         super().__init__()
+
+     def match(self, node) -> bool:  # type: ignore[empty-body]
+         return False
+
+     def convert(self, exported_program, node) -> None:  # type: ignore[empty-body]
+         pass
+
+
+ class ConvertHardTanhToReLU6(Converter):
+     def __init__(self):
+         super().__init__()
+
+     def match(self, node) -> bool:
+         if node.target == torch.ops.aten.hardtanh.default:
+             args = HardTanhArgs(*node.args, **node.kwargs)
+             min_val = args.min_val
+             max_val = args.max_val
+
+             # NOTE: int and float are both covered by PyTorch implicit type conversion
+             return min_val == 0.0 and max_val == 6.0
+
+         return False
+
+     def convert(self, exported_program, node):
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         args = HardTanhArgs(*node.args, **node.kwargs)
+         input = args.input
+
+         with graph.inserting_after(node):
+             relu_node = graph.call_function(torch.ops.aten.relu6.default, args=(input,))
+             node.replace_all_uses_with(relu_node, propagate_meta=True)
+
+
+ class ConvertClampToReLU6(Converter):
+     def __init__(self):
+         super().__init__()
+
+     def match(self, node) -> bool:
+         if node.target == torch.ops.aten.clamp.default:
+             args = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+             min_val = args.min
+             max_val = args.max
+
+             # NOTE: int and float are both covered by PyTorch implicit type conversion
+             return min_val == 0 and max_val == 6
+
+         return False
+
+     def convert(self, exported_program, node):
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         args = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+         input = args.input
+
+         with graph.inserting_after(node):
+             relu_node = graph.call_function(torch.ops.aten.relu6.default, args=(input,))
+             node.replace_all_uses_with(relu_node, propagate_meta=True)
+
+
+ class ConvertDoubleClampsToReLU6(Converter):
+     def __init__(self):
+         super().__init__()
+
+     def match(self, node) -> bool:
+         """
+         Matches a pattern of two consecutive clamps that together are equivalent to
+         a single clamp with a min value of 0 and a max value of 6.
+
+                             (equivalent)
+         input                             input
+           |                                 |
+         node_prev (min, max)              node (0, 6)
+           |                                 |
+         node (min', max')                   |
+           |                                 |
+         output                            output
+
+         where max(min, min') == 0 and min(max, max') == 6, so the pair is equivalent
+         to clamp(input, 0, 6).
+
+         TODO Make this step more generic. For now we only support the case above.
+         """
+         if not node.target == torch.ops.aten.clamp.default:
+             return False
+
+         args = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+         node_prev = args.input
+         min_val = args.min if args.min is not None else float("-inf")
+         max_val = args.max if args.max is not None else float("inf")
+
+         if not node_prev.target == torch.ops.aten.clamp.default:
+             return False
+
+         prev_args = ClampArgs(*node_prev.args, **node_prev.kwargs)  # type: ignore[arg-type]
+         min_val_prev = prev_args.min if prev_args.min is not None else float("-inf")
+         max_val_prev = prev_args.max if prev_args.max is not None else float("inf")
+
+         # NOTE: int and float are both covered by PyTorch implicit type conversion
+         if max(min_val, min_val_prev) == 0 and min(max_val, max_val_prev) == 6:
+             return True
+
+         return False
+
+     def convert(self, exported_program, node):
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         args = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+         prev_node = args.input
+         prev_args = ClampArgs(*prev_node.args, **prev_node.kwargs)  # type: ignore[arg-type]
+         input = prev_args.input
+
+         with graph.inserting_after(node):
+             relu_node = graph.call_function(torch.ops.aten.relu6.default, args=(input,))
+             node.replace_all_uses_with(relu_node, propagate_meta=True)
+
+
+ @trace_graph_diff_on_pass
+ class ConvertToReLU6(PassBase):
+     def __init__(self):
+         super().__init__()
+         self.converters: List[Converter] = [
+             ConvertHardTanhToReLU6(),
+             ConvertClampToReLU6(),
+             ConvertDoubleClampsToReLU6(),
+         ]
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         modified = False
+         for node in graph.nodes:
+             if not node.op == "call_function":
+                 continue
+
+             for converter in self.converters:
+                 if not converter.match(node):
+                     continue
+
+                 converter.convert(exported_program, node)
+                 modified = True
+                 logger.debug(f"{node.name} is replaced with ReLU6 operator")
+                 break
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         return PassResult(modified)
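
All three converters rely on the same identity: relu6(x) == clamp(x, 0, 6) == hardtanh(x, 0.0, 6.0), and for the double-clamp case clamp(clamp(x, min, max), min', max') == clamp(x, 0, 6) whenever max(min, min') == 0 and min(max, max') == 6. A minimal sanity check of those equivalences, assuming only stock PyTorch (the shapes and clamp bounds are illustrative):

    import torch

    x = torch.randn(4, 8) * 10  # values well outside [0, 6]

    relu6 = torch.ops.aten.relu6.default(x)  # the op the pass inserts
    assert torch.equal(relu6, torch.clamp(x, 0, 6))
    assert torch.equal(relu6, torch.nn.functional.hardtanh(x, 0.0, 6.0))
    # Double clamp: max(-1, 0) == 0 and min(6, 7) == 6, so the pair folds to relu6.
    assert torch.equal(torch.clamp(torch.clamp(x, -1, 6), 0, 7), relu6)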
tico/passes/decompose_addmm.py
@@ -0,0 +1,127 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.serialize.circle_mapping import extract_shape
+ from tico.utils import logging
+ from tico.utils.graph import add_placeholder
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.utils import set_new_meta_val
+ from tico.utils.validate_args_kwargs import AddmmArgs
+
+
+ @trace_graph_diff_on_pass
+ class DecomposeAddmm(PassBase):
+     """
+     Let's decompose addmm into add + mul + matmul.
+
+     [BEFORE]
+
+       input  mat1  mat2  beta  alpha
+         |      |     |     |     |
+         -------------addmm-------------
+                        |
+                       out
+
+     [AFTER]
+
+       input  beta  mat1  mat2  alpha
+         |     |     |      |     |
+         --mul--     ---mm---     |
+            |            |        |
+            |            ----mul---
+            |                |
+            -------add--------
+                    |
+                   out
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         gm = exported_program.graph_module
+         graph: torch.fx.Graph = gm.graph
+         modified = False
+
+         for node in graph.nodes:
+             if node.op != "call_function":
+                 continue
+
+             if node.target in [
+                 torch.ops.aten.addmm.default,
+             ]:
+                 args = AddmmArgs(*node.args, **node.kwargs)
+                 input = args.input
+                 mat1 = args.mat1
+                 mat2 = args.mat2
+                 beta = args.beta
+                 alpha = args.alpha
+
+                 with graph.inserting_before(node):
+                     # out = beta * input + alpha * (mat1 @ mat2)
+                     matmul = graph.call_function(
+                         torch.ops.aten.mm.default, (mat1, mat2)
+                     )
+                     set_new_meta_val(matmul)
+
+                     if beta == 1:
+                         bias: torch.fx.Node | torch.Tensor = input
+                     elif beta == 0:
+                         bias = add_placeholder(
+                             exported_program,
+                             torch.zeros(extract_shape(input)),
+                             f"{node.name}_beta_zeros",
+                         )
+                     else:
+                         bias = graph.call_function(
+                             torch.ops.aten.mul.Tensor, (input, beta)
+                         )
+
+                     if alpha == 1:
+                         scaled_matmul: torch.fx.Node | torch.Tensor = matmul
+                     elif alpha == 0:
+                         scaled_matmul = add_placeholder(
+                             exported_program,
+                             torch.zeros(extract_shape(matmul)),
+                             f"{node.name}_alpha_zeros",
+                         )
+                     else:
+                         scaled_matmul = graph.call_function(
+                             torch.ops.aten.mul.Tensor, (matmul, alpha)
+                         )
+
+                     result = graph.call_function(
+                         torch.ops.aten.add.Tensor, (bias, scaled_matmul)
+                     )
+
+                 node.replace_all_uses_with(result, propagate_meta=True)
+
+                 modified = True
+
+         gm.graph.eliminate_dead_code()
+         gm.graph.lint()
+         gm.recompile()
+
+         return PassResult(modified)
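
The decomposition implements the identity out = beta * input + alpha * (mat1 @ mat2); the beta == 0 and alpha == 0 branches only exist to materialize zero placeholders instead of inserting a multiplication by zero. A short sketch checking the identity against torch.addmm (shapes and scalars are illustrative):

    import torch

    input = torch.randn(3, 5)
    mat1 = torch.randn(3, 4)
    mat2 = torch.randn(4, 5)
    beta, alpha = 0.5, 2.0

    expected = torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
    decomposed = beta * input + alpha * torch.mm(mat1, mat2)  # mul + mm + mul + add
    assert torch.allclose(expected, decomposed)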
tico/passes/decompose_batch_norm.py
@@ -0,0 +1,198 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.serialize.circle_mapping import extract_shape
+ from tico.utils import logging
+ from tico.utils.errors import NotYetSupportedError
+ from tico.utils.graph import (
+     add_placeholder,
+     get_first_user_input,
+     get_torch_buffer_value,
+     get_torch_param_value,
+     is_torch_buffer,
+     is_torch_param,
+ )
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.utils import fill_meta_val
+ from tico.utils.validate_args_kwargs import NativeBatchNormLegitNoTrainingArgs
+
+
+ def insert_node(graph: torch.fx.Graph, operation, args):
+     new_node = graph.call_function(operation, args)
+
+     return new_node
+
+
+ @trace_graph_diff_on_pass
+ class DecomposeBatchNorm(PassBase):
+     """
+     [BatchNorm]
+
+     The op can be decomposed into a single aten.mul and a single aten.add because
+     mean and var are fixed during evaluation.
+
+       W = weight / sqrt(var + eps)
+       B = bias - (mean * weight) / sqrt(var + eps)
+       Y = X * W + B
+
+     [before]
+
+       input (tensor, weight, bias, running_mean, running_var, momentum, eps)
+         |
+       BatchNorm
+         |
+       output
+
+     [after]
+
+        input
+       (tensor)
+          |  W
+          | /
+         mul
+          |  B
+          | /
+         add
+          |
+        output
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         gm = exported_program.graph_module
+         graph: torch.fx.Graph = gm.graph
+         modified = False
+
+         for node in graph.nodes:
+             if node.op != "call_function":
+                 continue
+
+             if node.target in [
+                 torch.ops.aten._native_batch_norm_legit_no_training.default,
+             ]:
+                 args = NativeBatchNormLegitNoTrainingArgs(*node.args)
+                 input_ = args.input
+                 weight = args.weight
+                 bias = args.bias
+                 running_mean = args.running_mean
+                 running_var = args.running_var
+                 eps = args.eps
+
+                 if not running_mean:
+                     raise NotYetSupportedError(
+                         "running_mean=None is not supported yet"
+                     )
+                 if not running_var:
+                     raise NotYetSupportedError("running_var=None is not supported yet")
+
+                 """
+                 Only the cases generated from the torch.nn.BatchNorm2d module are
+                 supported. For those, check that weight and bias are parameters and
+                 that running_mean and running_var are buffers.
+                 """
+                 if weight and not is_torch_param(weight, exported_program):
+                     continue
+                 if bias and not is_torch_param(bias, exported_program):
+                     continue
+                 if not is_torch_buffer(running_mean, exported_program):
+                     continue
+                 if not is_torch_buffer(running_var, exported_program):
+                     continue
+
+                 input_shape = extract_shape(input_)
+                 assert len(input_shape) == 4
+                 C = input_shape[1]
+
+                 weight_value = (
+                     get_torch_param_value(weight, exported_program)
+                     if weight
+                     else torch.tensor([1] * C)
+                 )
+                 bias_value = (
+                     get_torch_param_value(bias, exported_program)
+                     if bias
+                     else torch.tensor([0] * C)
+                 )
+                 mean_value = get_torch_buffer_value(running_mean, exported_program)
+                 var_value = get_torch_buffer_value(running_var, exported_program)
+
+                 assert isinstance(weight_value, torch.Tensor)
+                 assert isinstance(bias_value, torch.Tensor)
+                 assert isinstance(mean_value, torch.Tensor)
+                 assert isinstance(var_value, torch.Tensor)
+
+                 assert (
+                     weight_value.shape
+                     == bias_value.shape
+                     == mean_value.shape
+                     == var_value.shape
+                 )
+                 # Calculate constants for mul and add
+                 mul_const = weight_value / torch.sqrt(var_value + eps)
+                 add_const = bias_value - (mul_const * mean_value)
+                 # N, C, H, W
+                 assert len(mul_const) == len(add_const) == C
+                 # Reshape along the channel dimension
+                 mul_const = mul_const.view(1, mul_const.shape[0], 1, 1)
+                 add_const = add_const.view(1, add_const.shape[0], 1, 1)
+
+                 # Placeholder nodes must be the first N nodes in the nodes list of a graph.
+                 # Therefore, insert the newly created placeholders at the start of the node list.
+                 with exported_program.graph.inserting_before(
+                     get_first_user_input(exported_program)
+                 ):
+                     mul_const_node = add_placeholder(
+                         exported_program,
+                         mul_const,
+                         prefix=f"{node.name}_mul_const",
+                     )
+                     add_const_node = add_placeholder(
+                         exported_program,
+                         add_const,
+                         prefix=f"{node.name}_add_const",
+                     )
+
+                 with gm.graph.inserting_before(node):
+                     mul = graph.call_function(
+                         torch.ops.aten.mul.Tensor,
+                         args=(input_, mul_const_node),
+                     )
+                     add = graph.call_function(
+                         torch.ops.aten.add.Tensor,
+                         args=(mul, add_const_node),
+                     )
+                     # Meta is not set here; propagate_meta=True below propagates
+                     # get_item's meta to the replacement instead.
+                     get_item, *_ = node.users.keys()
+                     get_item.replace_all_uses_with(add, propagate_meta=True)
+
+                 fill_meta_val(exported_program)
+                 modified = True
+
+         gm.graph.eliminate_dead_code()
+         gm.graph.lint()
+         gm.recompile()
+
+         return PassResult(modified)
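
The folded constants come from the eval-mode formula Y = (X - mean) / sqrt(var + eps) * weight + bias, which collapses to Y = X * W + B with W and B as in the docstring. A standalone numeric check, assuming stock PyTorch (channel count and input shape are illustrative):

    import torch

    bn = torch.nn.BatchNorm2d(8).eval()
    bn.running_mean.uniform_(-1, 1)    # non-trivial statistics
    bn.running_var.uniform_(0.5, 2.0)
    x = torch.randn(2, 8, 4, 4)        # N, C, H, W

    W = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    B = bn.bias - bn.running_mean * W
    # Reshape along the channel dimension, as the pass does for its placeholders.
    folded = x * W.view(1, -1, 1, 1) + B.view(1, -1, 1, 1)
    assert torch.allclose(bn(x), folded, atol=1e-6)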
tico/passes/decompose_fake_quantize.py
@@ -0,0 +1,126 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+
+ # Imported to register the torch.ops.quantized_decomposed operators
+ from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
+ from torch.export import ExportedProgram
+
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.validate_args_kwargs import FakeQuantizePerChannelArgs
+
+
+ def get_quant_type(min: int, max: int) -> torch.dtype:
+     if min == 0 and max == 255:
+         return torch.uint8
+     if min == -32768 and max == 32767:
+         return torch.int16
+     if min == -32767 and max == 32767:
+         return torch.int16
+
+     raise RuntimeError("Unsupported min/max values")
+
+
+ @trace_graph_diff_on_pass
+ class DecomposeFakeQuantize(PassBase):
+     """
+     Decompose the fake quantize operator into quant/dequant operators.
+     Otherwise, it can't be converted to the edge IR because the fake quantize
+     operator is not Aten canonical.
+
+     [Before]
+     def forward(self, x):
+         fake_quantize_per_tensor_affine = torch.ops.aten.fake_quantize_per_tensor_affine.default(tensor, scale, zero_p, quant_min, quant_max); x = None
+         return (fake_quantize_per_tensor_affine,)
+
+     [After]
+     def forward(self, x):
+         quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(tensor, scale, zero_p, quant_min, quant_max, dtype = ${torch.dtype}); x = None
+         dequantize_per_tensor_default = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default, scale, zero_p, quant_min, quant_max, dtype = ${torch.dtype}); quantize_per_tensor_default = None
+         return (dequantize_per_tensor_default,)
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+         modified = False
+
+         gm = exported_program.graph_module
+         qd = torch.ops.quantized_decomposed  # type: ignore[return]
+         for node in gm.graph.nodes:
+             if node.op == "call_function" and node.target in [
+                 torch.ops.aten.fake_quantize_per_tensor_affine.default
+             ]:
+                 # tensor, scale, zero_p, quant_min, quant_max
+                 assert len(node.args) == 5
+                 _, _, _, quant_min, quant_max = node.args
+
+                 quant_kwargs = {
+                     **node.kwargs,
+                     **{"dtype": get_quant_type(quant_min, quant_max)},
+                 }
+                 with gm.graph.inserting_before(node):
+                     quant = gm.graph.call_function(
+                         qd.quantize_per_tensor.default,
+                         args=node.args,
+                         kwargs=quant_kwargs,
+                     )
+                     dequant = gm.graph.call_function(
+                         qd.dequantize_per_tensor.default,
+                         args=(quant, *quant.args[1:]),
+                         kwargs=quant.kwargs,
+                     )
+                 # Meta is not set here; propagate_meta=True propagates the original
+                 # node's meta to the replacement instead.
+                 node.replace_all_uses_with(dequant, propagate_meta=True)
+                 modified = True
+
+             if node.op == "call_function" and node.target in [
+                 torch.ops.aten.fake_quantize_per_channel_affine.default
+             ]:
+                 fq_args = FakeQuantizePerChannelArgs(*node.args, **node.kwargs)
+                 quant_min = fq_args.quant_min
+                 quant_max = fq_args.quant_max
+
+                 quant_kwargs = {
+                     **node.kwargs,
+                     **{"dtype": get_quant_type(quant_min, quant_max)},
+                 }
+                 with gm.graph.inserting_before(node):
+                     quant = gm.graph.call_function(
+                         qd.quantize_per_channel.default,
+                         args=node.args,
+                         kwargs=quant_kwargs,
+                     )
+                     dequant = gm.graph.call_function(
+                         qd.dequantize_per_channel.default,
+                         args=(quant, *quant.args[1:]),
+                         kwargs=quant.kwargs,
+                     )
+                 # Meta is not set here; propagate_meta=True propagates the original
+                 # node's meta to the replacement instead.
+                 node.replace_all_uses_with(dequant, propagate_meta=True)
+                 modified = True
+
+         gm.graph.eliminate_dead_code()
+         gm.graph.lint()
+         gm.recompile()
+
+         return PassResult(modified)
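
The rewrite preserves values because fake quantize is itself a quantize/dequantize round trip. A minimal sketch of the per-tensor case against the quantized_decomposed ops the pass inserts (the scale, zero point, and uint8 range 0..255 are illustrative):

    import torch
    from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # registers the ops

    x = torch.randn(16)
    scale, zero_point, quant_min, quant_max = 0.1, 3, 0, 255

    fq = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, quant_min, quant_max)

    qd = torch.ops.quantized_decomposed
    dtype = torch.uint8  # what get_quant_type(0, 255) returns
    q = qd.quantize_per_tensor.default(x, scale, zero_point, quant_min, quant_max, dtype=dtype)
    dq = qd.dequantize_per_tensor.default(q, scale, zero_point, quant_min, quant_max, dtype=dtype)
    assert torch.equal(fq, dq)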