tico-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
tico/passes/convert_to_relu6.py
@@ -0,0 +1,181 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.utils import logging
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import ClampArgs, HardTanhArgs
+
+
+class Converter:  # type: ignore[empty-body]
+    def __init__(self):
+        super().__init__()
+
+    def match(self, node) -> bool:  # type: ignore[empty-body]
+        return False
+
+    def convert(self, exported_program, node) -> None:  # type: ignore[empty-body]
+        pass
+
+
+class ConvertHardTanhToReLU6(Converter):
+    def __init__(self):
+        super().__init__()
+
+    def match(self, node) -> bool:
+        if node.target == torch.ops.aten.hardtanh.default:
+            args = HardTanhArgs(*node.args, **node.kwargs)
+            min_val = args.min_val
+            max_val = args.max_val
+
+            # NOTE: int and float are both covered by PyTorch's implicit type conversion.
+            return min_val == 0.0 and max_val == 6.0
+
+        return False
+
+    def convert(self, exported_program, node):
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        args = HardTanhArgs(*node.args, **node.kwargs)
+        input = args.input
+
+        with graph.inserting_after(node):
+            relu_node = create_node(graph, torch.ops.aten.relu6.default, args=(input,))
+            node.replace_all_uses_with(relu_node, propagate_meta=True)
+
+
+class ConvertClampToReLU6(Converter):
+    def __init__(self):
+        super().__init__()
+
+    def match(self, node) -> bool:
+        if node.target == torch.ops.aten.clamp.default:
+            args = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+            min_val = args.min
+            max_val = args.max
+
+            # NOTE: int and float are both covered by PyTorch's implicit type conversion.
+            return min_val == 0 and max_val == 6
+
+        return False
+
+    def convert(self, exported_program, node):
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        args = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input = args.input
+
+        with graph.inserting_after(node):
+            relu_node = create_node(graph, torch.ops.aten.relu6.default, args=(input,))
+            node.replace_all_uses_with(relu_node, propagate_meta=True)
+
+
+class ConvertDoubleClampsToReLU6(Converter):
+    def __init__(self):
+        super().__init__()
+
+    def match(self, node) -> bool:
+        """
+        Matches a pair of consecutive clamps that is equivalent to a single clamp
+        with a min value of 0 and a max value of 6.
+
+            (equivalent)
+            input                       input
+              |                           |
+            node_prev (min, max)        node (0, 6)
+              |                           |
+            node (min', max')             |
+              |                           |
+            output                      output
+
+        where max(min, min') == 0 and min(max, max') == 6, so the pair is
+        equivalent to clamp(input, 0, 6).
+
+        TODO Make this step more generic. For now, only the case above is supported.
+        """
+        if not node.target == torch.ops.aten.clamp.default:
+            return False
+
+        args = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        node_prev = args.input
+        min_val = args.min if args.min is not None else float("-inf")
+        max_val = args.max if args.max is not None else float("inf")
+
+        if not node_prev.target == torch.ops.aten.clamp.default:
+            return False
+
+        prev_args = ClampArgs(*node_prev.args, **node_prev.kwargs)  # type: ignore[arg-type]
+        min_val_prev = prev_args.min if prev_args.min is not None else float("-inf")
+        max_val_prev = prev_args.max if prev_args.max is not None else float("inf")
+
+        # NOTE: int and float are both covered by PyTorch's implicit type conversion.
+        if max(min_val, min_val_prev) == 0 and min(max_val, max_val_prev) == 6:
+            return True
+
+        return False
+
+    def convert(self, exported_program, node):
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        args = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        prev_node = args.input
+        prev_args = ClampArgs(*prev_node.args, **prev_node.kwargs)  # type: ignore[arg-type]
+        input = prev_args.input
+
+        with graph.inserting_after(node):
+            relu_node = create_node(graph, torch.ops.aten.relu6.default, args=(input,))
+            node.replace_all_uses_with(relu_node, propagate_meta=True)
+
+
+@trace_graph_diff_on_pass
+class ConvertToReLU6(PassBase):
+    def __init__(self):
+        super().__init__()
+        self.converters: List[Converter] = [
+            ConvertHardTanhToReLU6(),
+            ConvertClampToReLU6(),
+            ConvertDoubleClampsToReLU6(),
+        ]
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+        for node in graph.nodes:
+            if not node.op == "call_function":
+                continue
+
+            for converter in self.converters:
+                if not converter.match(node):
+                    continue
+
+                converter.convert(exported_program, node)
+                modified = True
+                logger.debug(f"{node.name} is replaced with the ReLU6 operator")
+                break
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)
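
All three patterns matched by this pass are numerically identical to relu6, which is the premise the converters rely on. A minimal standalone sketch of those equivalences, using only stock PyTorch (not tico):

import torch
import torch.nn.functional as F

x = torch.randn(16)

# hardtanh with (min_val, max_val) == (0, 6) is relu6
assert torch.equal(F.hardtanh(x, 0.0, 6.0), F.relu6(x))
# a single clamp to [0, 6] is relu6
assert torch.equal(x.clamp(0, 6), F.relu6(x))
# two stacked clamps where max(min, min') == 0 and min(max, max') == 6
assert torch.equal(x.clamp(-1, 6).clamp(0, 8), F.relu6(x))
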
tico/passes/decompose_addmm.py
@@ -0,0 +1,124 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.circle_mapping import extract_shape
+from tico.utils import logging
+from tico.utils.graph import add_placeholder, create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.utils import is_target_node, set_new_meta_val
+from tico.utils.validate_args_kwargs import AddmmArgs
+
+
+@trace_graph_diff_on_pass
+class DecomposeAddmm(PassBase):
+    """
+    Let's decompose addmm into add + mul + matmul.
+
+    [BEFORE]
+
+      input  mat1  mat2  beta  alpha
+        |      |     |     |     |
+        -------------addmm-------------
+                       |
+                      out
+
+    [AFTER]
+
+      input   beta   mat1   mat2   alpha
+        |      |       |      |      |
+        ---mul--       ---mm---      |
+           |               |         |
+           |               ----mul----
+           |                    |
+           ----------add---------
+                      |
+                     out
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        gm = exported_program.graph_module
+        graph: torch.fx.Graph = gm.graph
+        modified = False
+
+        for node in graph.nodes:
+            if not is_target_node(node, torch.ops.aten.addmm.default):
+                continue
+
+            args = AddmmArgs(*node.args, **node.kwargs)
+            input = args.input
+            mat1 = args.mat1
+            mat2 = args.mat2
+            beta = args.beta
+            alpha = args.alpha
+
+            with graph.inserting_before(node):
+                # out = beta * input + alpha * (mat1 @ mat2)
+                matmul = create_node(
+                    graph, torch.ops.aten.mm.default, (mat1, mat2), origin=node
+                )
+                set_new_meta_val(matmul)
+
+                if beta == 1:
+                    bias: torch.fx.Node | torch.Tensor = input
+                elif beta == 0:
+                    bias = add_placeholder(
+                        exported_program,
+                        torch.zeros(extract_shape(input)),
+                        f"{node.name}_beta_zeros",
+                    )
+                else:
+                    bias = create_node(
+                        graph, torch.ops.aten.mul.Tensor, (input, beta), origin=node
+                    )
+
+                if alpha == 1:
+                    scaled_matmul: torch.fx.Node | torch.Tensor = matmul
+                elif alpha == 0:
+                    scaled_matmul = add_placeholder(
+                        exported_program,
+                        torch.zeros(extract_shape(matmul)),
+                        f"{node.name}_alpha_zeros",
+                    )
+                else:
+                    scaled_matmul = create_node(
+                        graph, torch.ops.aten.mul.Tensor, (matmul, alpha), origin=node
+                    )
+
+                result = create_node(
+                    graph, torch.ops.aten.add.Tensor, (bias, scaled_matmul)
+                )
+
+                node.replace_all_uses_with(result, propagate_meta=True)
+
+            modified = True
+
+        gm.graph.eliminate_dead_code()
+        gm.graph.lint()
+        gm.recompile()
+
+        return PassResult(modified)
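
The decomposition implements the identity out = beta * input + alpha * (mat1 @ mat2), which is exactly what addmm computes. A quick numerical sketch of that identity in plain PyTorch, independent of the pass (shapes and scalars are arbitrary):

import torch

inp = torch.randn(3, 4)
mat1 = torch.randn(3, 5)
mat2 = torch.randn(5, 4)
beta, alpha = 0.5, 2.0

ref = torch.addmm(inp, mat1, mat2, beta=beta, alpha=alpha)
# decomposed form: mul + mm + mul + add, mirroring what the pass emits
dec = beta * inp + alpha * torch.mm(mat1, mat2)
assert torch.allclose(ref, dec, atol=1e-6)
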
tico/passes/decompose_batch_norm.py
@@ -0,0 +1,192 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.circle_mapping import extract_shape
+from tico.utils import logging
+from tico.utils.errors import NotYetSupportedError
+from tico.utils.graph import (
+    add_placeholder,
+    create_node,
+    get_first_user_input,
+    get_torch_buffer_value,
+    get_torch_param_value,
+    is_torch_buffer,
+    is_torch_param,
+)
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.utils import is_target_node
+from tico.utils.validate_args_kwargs import NativeBatchNormLegitNoTrainingArgs
+
+
+@trace_graph_diff_on_pass
+class DecomposeBatchNorm(PassBase):
+    """
+    [BatchNorm]
+
+    The op can be decomposed into a single aten.mul and a single aten.add because
+    mean and var are fixed during evaluation.
+
+        W = weight / sqrt(var + eps)
+        B = bias - (mean * weight) / sqrt(var + eps)
+        Y = X * W + B
+
+    [before]
+
+        input (tensor, weight, bias, running_mean, running_var, momentum, eps)
+          |
+        BatchNorm
+          |
+        output
+
+    [after]
+
+        input
+       (tensor)
+          |   W
+          |  /
+         mul
+          |   B
+          |  /
+         add
+          |
+        output
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        gm = exported_program.graph_module
+        graph: torch.fx.Graph = gm.graph
+        modified = False
+
+        for node in graph.nodes:
+            if not is_target_node(
+                node, torch.ops.aten._native_batch_norm_legit_no_training.default
+            ):
+                continue
+
+            args = NativeBatchNormLegitNoTrainingArgs(*node.args)
+            input_ = args.input
+            weight = args.weight
+            bias = args.bias
+            running_mean = args.running_mean
+            running_var = args.running_var
+            eps = args.eps
+
+            if not running_mean:
+                raise NotYetSupportedError("running_mean=None is not supported yet")
+            if not running_var:
+                raise NotYetSupportedError("running_var=None is not supported yet")
+
+            """
+            Only the cases generated from the torch.nn.BatchNorm2d module are
+            supported. For those, check that weight and bias are parameters and
+            that running_mean and running_var are buffers.
+            """
+            if weight and not is_torch_param(weight, exported_program):
+                continue
+            if bias and not is_torch_param(bias, exported_program):
+                continue
+            if not is_torch_buffer(running_mean, exported_program):
+                continue
+            if not is_torch_buffer(running_var, exported_program):
+                continue
+
+            input_shape = extract_shape(input_)
+            assert len(input_shape) == 4
+            C = input_shape[1]
+
+            weight_value = (
+                get_torch_param_value(weight, exported_program)
+                if weight
+                else torch.tensor([1] * C)
+            )
+            bias_value = (
+                get_torch_param_value(bias, exported_program)
+                if bias
+                else torch.tensor([0] * C)
+            )
+            mean_value = get_torch_buffer_value(running_mean, exported_program)
+            var_value = get_torch_buffer_value(running_var, exported_program)
+
+            assert isinstance(weight_value, torch.Tensor)
+            assert isinstance(bias_value, torch.Tensor)
+            assert isinstance(mean_value, torch.Tensor)
+            assert isinstance(var_value, torch.Tensor)
+
+            assert (
+                weight_value.shape
+                == bias_value.shape
+                == mean_value.shape
+                == var_value.shape
+            )
+            # Calculate constants for mul and add
+            mul_const = weight_value / torch.sqrt(var_value + eps)
+            add_const = bias_value - (mul_const * mean_value)
+            # N, C, H, W
+            assert len(mul_const) == len(add_const) == C
+            # Reshape along the channel dimension
+            mul_const = mul_const.view(1, mul_const.shape[0], 1, 1)
+            add_const = add_const.view(1, add_const.shape[0], 1, 1)
+
+            # Placeholder nodes must be the first N nodes in the nodes list of a graph.
+            # Therefore, insert the newly created placeholders at the start of the node list.
+            with exported_program.graph.inserting_before(
+                get_first_user_input(exported_program)
+            ):
+                mul_const_node = add_placeholder(
+                    exported_program,
+                    mul_const,
+                    prefix=f"{node.name}_mul_const",
+                )
+                add_const_node = add_placeholder(
+                    exported_program,
+                    add_const,
+                    prefix=f"{node.name}_add_const",
+                )
+
+            with gm.graph.inserting_before(node):
+                mul = create_node(
+                    graph,
+                    torch.ops.aten.mul.Tensor,
+                    args=(input_, mul_const_node),
+                    origin=node,
+                )
+                add = create_node(
+                    graph,
+                    torch.ops.aten.add.Tensor,
+                    args=(mul, add_const_node),
+                )
+                get_item, *_ = node.users.keys()
+                get_item.replace_all_uses_with(add, propagate_meta=True)
+
+            logger.debug(f"{node.name} is decomposed into {mul.name} and {add.name}")
+            modified = True
+
+        gm.graph.eliminate_dead_code()
+        gm.graph.lint()
+        gm.recompile()
+
+        return PassResult(modified)
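
The W/B folding in the docstring holds for any eval-mode BatchNorm2d. A minimal sketch verifying Y = X * W + B against torch.nn.BatchNorm2d in plain PyTorch; the channel count, shapes, and randomized statistics are arbitrary choices for the check:

import torch

bn = torch.nn.BatchNorm2d(3).eval()
with torch.no_grad():  # give the affine params and running stats non-trivial values
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-1.0, 1.0)
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)

x = torch.randn(2, 3, 8, 8)
# W = weight / sqrt(var + eps), B = bias - mean * W, reshaped along the channel dim
W = (bn.weight / torch.sqrt(bn.running_var + bn.eps)).view(1, -1, 1, 1)
B = (bn.bias - bn.running_mean * bn.weight / torch.sqrt(bn.running_var + bn.eps)).view(1, -1, 1, 1)
assert torch.allclose(bn(x), x * W + B, atol=1e-5)
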
tico/passes/decompose_fake_quantize.py
@@ -0,0 +1,134 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+
+# To import the torch.ops.quantized_decomposed related operators
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
+from torch.export import ExportedProgram
+
+from tico.utils import logging
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import FakeQuantizePerChannelArgs
+
+
+def get_quant_type(min: int, max: int) -> torch.dtype:
+    if min == 0 and max == 15:
+        # torch can't represent "uint4".
+        # Let's set torch.uint8 and infer the dtype from quant_min/quant_max instead.
+        return torch.uint8
+    if min == 0 and max == 255:
+        return torch.uint8
+    if min == -32768 and max == 32767:
+        return torch.int16
+    if min == -32767 and max == 32767:
+        return torch.int16
+
+    raise RuntimeError(f"Unsupported min/max values: {min}/{max}")
+
+
+@trace_graph_diff_on_pass
+class DecomposeFakeQuantize(PassBase):
+    """
+    Decompose the fake quantize operator into quantize/dequantize operators.
+    Otherwise, it can't be converted to the edge IR because the fake quantize
+    operator is not ATen canonical.
+
+    [Before]
+    def forward(self, x):
+        fake_quantize_per_tensor_affine = torch.ops.aten.fake_quantize_per_tensor_affine.default(tensor, scale, zero_p, quant_min, quant_max); x = None
+        return (fake_quantize_per_tensor_affine,)
+
+    [After]
+    def forward(self, x):
+        quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(tensor, scale, zero_p, quant_min, quant_max, dtype = ${torch.dtype}); x = None
+        dequantize_per_tensor_default = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default, scale, zero_p, quant_min, quant_max, dtype = ${torch.dtype}); quantize_per_tensor_default = None
+        return (dequantize_per_tensor_default,)
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+        modified = False
+
+        gm = exported_program.graph_module
+        g = gm.graph
+        qd = torch.ops.quantized_decomposed  # type: ignore[return]
+        for node in gm.graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target in [torch.ops.aten.fake_quantize_per_tensor_affine.default]:
+                # tensor, scale, zero_p, quant_min, quant_max
+                assert len(node.args) == 5
+                _, _, _, quant_min, quant_max = node.args
+
+                quant_kwargs = {
+                    **node.kwargs,
+                    **{"dtype": get_quant_type(quant_min, quant_max)},
+                }
+                with gm.graph.inserting_before(node):
+                    quant = create_node(
+                        g,
+                        qd.quantize_per_tensor.default,
+                        args=node.args,
+                        kwargs=quant_kwargs,
+                        origin=node,
+                    )
+                    dequant = create_node(
+                        g,
+                        qd.dequantize_per_tensor.default,
+                        args=(quant, *quant.args[1:]),
+                        kwargs=quant.kwargs,
+                    )
+                    node.replace_all_uses_with(dequant, propagate_meta=True)
+                    modified = True
+
+            if node.target in [torch.ops.aten.fake_quantize_per_channel_affine.default]:
+                fq_args = FakeQuantizePerChannelArgs(*node.args, **node.kwargs)
+                quant_min = fq_args.quant_min
+                quant_max = fq_args.quant_max
+
+                quant_kwargs = {
+                    **node.kwargs,
+                    **{"dtype": get_quant_type(quant_min, quant_max)},
+                }
+                with gm.graph.inserting_before(node):
+                    quant = create_node(
+                        g,
+                        qd.quantize_per_channel.default,
+                        args=node.args,
+                        kwargs=quant_kwargs,
+                        origin=node,
+                    )
+                    dequant = create_node(
+                        g,
+                        qd.dequantize_per_channel.default,
+                        args=(quant, *quant.args[1:]),
+                        kwargs=quant.kwargs,
+                    )
+                    node.replace_all_uses_with(dequant, propagate_meta=True)
+                    modified = True
+
+        gm.graph.eliminate_dead_code()
+        gm.graph.lint()
+        gm.recompile()
+
+        return PassResult(modified)
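
The quantize/dequantize pair produced by this pass round-trips to the same values as the original fake-quantize node. A small sketch of that equivalence in plain PyTorch, spelling out the affine math; the scale and zero-point values are arbitrary examples within the uint8 range handled by get_quant_type:

import torch

x = torch.randn(4, 4)
scale, zero_p, quant_min, quant_max = 0.05, 128, 0, 255

fq = torch.fake_quantize_per_tensor_affine(x, scale, zero_p, quant_min, quant_max)

# quantize: round onto the integer grid, then clamp to [quant_min, quant_max]
q = torch.clamp(torch.round(x / scale) + zero_p, quant_min, quant_max)
# dequantize: map the integers back to float
dq = (q - zero_p) * scale
assert torch.allclose(fq, dq)
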