tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,235 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.serialize.circle_mapping import extract_shape
23
+ from tico.utils import logging
24
+ from tico.utils.errors import NotYetSupportedError
25
+ from tico.utils.graph import create_node
26
+ from tico.utils.passes import PassBase, PassResult
27
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
+ from tico.utils.utils import is_target_node
29
+ from tico.utils.validate_args_kwargs import IndexArgs, UpsampleNearest2DVecArgs
30
+
31
+
32
@trace_graph_diff_on_pass
class LowerToResizeNearestNeighbor(PassBase):
    """
    This pass lowers `aten.index` and `aten.upsample_nearest2d.vec` to
    `circle_custom.resize_nearest_neighbor` when it is possible.

    Until torch 2.7, `torch.nn.functional.interpolate` is converted to `aten.index` op.
    [BEFORE PASS]
      input - aten.index - output

    [AFTER PASS]
      input - aten.permute(NCHW_to_NHWC) - circle_custom.resize_nearest_neighbor - aten.permute(NHWC_to_NCHW) - output

    Since torch 2.8, `torch.nn.functional.interpolate` is converted to `aten.upsample_nearest2d.vec` op.
    [BEFORE PASS]
      input - aten.upsample_nearest2d.vec - output

    [AFTER PASS]
      input - aten.permute(NCHW_to_NHWC) - circle_custom.resize_nearest_neighbor - aten.permute(NHWC_to_NCHW) - output
    """

    def __init__(self):
        super().__init__()

    def _insert_resize_nearest_neighbor(self, graph, node, input_tensor, output_hw):
        """
        Insert `permute - resize_nearest_neighbor - permute` before `node` and
        redirect every use of `node` to the last permute.

        Args:
            graph: The fx graph that owns `node`.
            node: The node being lowered; it is left in the graph without users.
            input_tensor: The node feeding the original operator.
            output_hw: `[height, width]` passed to the resize operator.

        Returns:
            The created `circle_custom.resize_nearest_neighbor` node.
        """
        with graph.inserting_before(node):
            nchw_to_nhwc = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(input_tensor, [0, 2, 3, 1]),
                origin=input_tensor,
            )
            resize_nearest_neighbor = create_node(
                graph,
                torch.ops.circle_custom.resize_nearest_neighbor,
                args=(nchw_to_nhwc, output_hw),
                origin=node,
            )
            nhwc_to_nchw = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(resize_nearest_neighbor, [0, 3, 1, 2]),
            )
            node.replace_all_uses_with(nhwc_to_nchw, propagate_meta=True)

        return resize_nearest_neighbor

    def convert_index_to_resize_nearest_neighbor(
        self, exported_program, node
    ) -> Optional[torch.fx.Node]:
        """
        Lower an `aten.index` node to a resize when its indices describe an
        integer-factor nearest-neighbor upscale of the H and W axes.

        Args:
            exported_program: Program owning the graph and the `constants` table.
            node: The `aten.index.Tensor` node to inspect.

        Returns:
            The created resize node, or None when the node does not match the
            supported pattern (the graph is left untouched in that case).
        """
        graph = exported_program.graph_module.graph

        args = IndexArgs(*node.args, **node.kwargs)
        input_tensor = args.input
        indices = args.indices

        # Only support 4-D tensor
        if len(indices) != 4:
            return None
        # indices = [None, None, H index, W index]
        N, C, H, W = indices
        if N is not None or C is not None:
            return None
        if not isinstance(H, torch.fx.Node) or not isinstance(W, torch.fx.Node):
            return None

        # Both index tensors must be known constants to be checked below.
        constants_dict = exported_program.constants
        if (H.name not in constants_dict) or (W.name not in constants_dict):
            return None
        H_index, W_index = constants_dict[H.name], constants_dict[W.name]

        input_tensor_shape = extract_shape(input_tensor)
        input_tensor_H, input_tensor_W = (
            input_tensor_shape[2],
            input_tensor_shape[3],
        )
        if H_index.size()[0] % input_tensor_H != 0:
            return None
        scale_factor = int(H_index.size()[0] / input_tensor_H)
        # H and W should be resized with same ratio.
        if scale_factor != W_index.size()[0] / input_tensor_W:
            return None

        # Each source row index repeated `scale_factor` times, as a column:
        # [[0], [0], [1], [1], ...] for scale_factor == 2.
        expected_H_index = [
            [i] for i in range(input_tensor_H) for _ in range(scale_factor)
        ]
        # Each source column index repeated `scale_factor` times, flat:
        # [0, 0, 1, 1, ...] for scale_factor == 2.
        expected_W_index = [
            i for i in range(input_tensor_W) for _ in range(scale_factor)
        ]
        if not torch.all(
            torch.eq(H_index, torch.tensor(expected_H_index))
        ) or not torch.all(torch.eq(W_index, torch.tensor(expected_W_index))):
            return None

        expected_shape = [
            input_tensor_shape[0],
            input_tensor_shape[1],
            len(expected_H_index),
            len(expected_W_index),
        ]
        assert expected_shape == list(extract_shape(node))

        return self._insert_resize_nearest_neighbor(
            graph,
            node,
            input_tensor,
            [len(expected_H_index), len(expected_W_index)],
        )

    def convert_upsample_nearest2d_to_resize_nearest_neighbor(
        self, exported_program, node
    ) -> Optional[torch.fx.Node]:
        """
        Lower an `aten.upsample_nearest2d.vec` node given by `scale_factors`.

        Args:
            exported_program: Program owning the graph.
            node: The `aten.upsample_nearest2d.vec` node to lower.

        Returns:
            The created resize node.

        Raises:
            NotYetSupportedError: If `output_size` is given, `scale_factors` is
                None, or a scaled dimension is not a whole number.
            AssertionError: If the scale factors are not positive floats.
        """
        graph = exported_program.graph_module.graph

        args = UpsampleNearest2DVecArgs(*node.args, **node.kwargs)
        input_tensor = args.input
        output_size = args.output_size
        scale_factors = args.scale_factors

        input_tensor_shape = extract_shape(input_tensor)
        input_tensor_H, input_tensor_W = (
            input_tensor_shape[2],
            input_tensor_shape[3],
        )

        # TODO Support output_size case. Currently only scale_factors case is supported.
        if output_size is not None:
            raise NotYetSupportedError("output_size is not supported yet")

        if scale_factors is None:
            raise NotYetSupportedError("scale_factors is None")

        assert (
            isinstance(scale_factors[0], float)
            and isinstance(scale_factors[1], float)
            and scale_factors[0] > 0
            and scale_factors[1] > 0
        ), "scale_factors must be two positive floats"

        def close_enough(x, y, epsilon=1e-10):
            return abs(x - y) < epsilon

        # The truncated size must equal the exact product, i.e. the scaled
        # dimension has to be a whole number.
        expected_H = int(input_tensor_H * scale_factors[0])
        if not close_enough(expected_H, input_tensor_H * scale_factors[0]):
            raise NotYetSupportedError(
                f"Cannot support input_tensor_H ({input_tensor_H}) with scaling factor ({scale_factors[0]})"
            )

        expected_W = int(input_tensor_W * scale_factors[1])
        if not close_enough(expected_W, input_tensor_W * scale_factors[1]):
            raise NotYetSupportedError(
                f"Cannot support input_tensor_W ({input_tensor_W}) with scaling factor ({scale_factors[1]})"
            )

        return self._insert_resize_nearest_neighbor(
            graph, node, input_tensor, [expected_H, expected_W]
        )

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Try to lower every `aten.index.Tensor` / `aten.upsample_nearest2d.vec`
        node, then clean up and recompile the graph.

        Returns:
            A PassResult whose flag is True when at least one node was lowered.
        """
        logger = logging.getLogger(__name__)

        modified = False
        graph_module = exported_program.graph_module
        graph = graph_module.graph
        for node in graph.nodes:
            if not is_target_node(
                node,
                [torch.ops.aten.index.Tensor, torch.ops.aten.upsample_nearest2d.vec],
            ):
                continue

            resize_nearest_neighbor = None
            if node.target == torch.ops.aten.index.Tensor:
                resize_nearest_neighbor = self.convert_index_to_resize_nearest_neighbor(
                    exported_program, node
                )
            elif node.target == torch.ops.aten.upsample_nearest2d.vec:
                resize_nearest_neighbor = (
                    self.convert_upsample_nearest2d_to_resize_nearest_neighbor(
                        exported_program, node
                    )
                )

            if resize_nearest_neighbor is not None:
                modified = True
                logger.debug(
                    f"{node.name} is replaced with {resize_nearest_neighbor.name} operator"
                )

        # Lowered nodes have no users left; drop them and validate the graph.
        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
@@ -0,0 +1,230 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch._export.utils import (
21
+ get_buffer,
22
+ get_lifted_tensor_constant,
23
+ get_param,
24
+ is_buffer,
25
+ is_lifted_tensor_constant,
26
+ is_param,
27
+ )
28
+ from torch.export import ExportedProgram
29
+
30
+ from tico.passes import ops
31
+ from tico.serialize.circle_graph import extract_shape
32
+ from tico.utils import logging
33
+ from tico.utils.graph import create_node, is_single_value_tensor
34
+ from tico.utils.passes import PassBase, PassResult
35
+ from tico.utils.trace_decorators import trace_const_diff_on_pass
36
+ from tico.utils.utils import is_target_node
37
+ from tico.utils.validate_args_kwargs import IndexSelectArgs, SelectCopyIntArgs
38
+
39
+
40
def passes():
    """
    Return the passes that lower `aten.select.int` / `aten.select_copy.int`
    and `aten.index_select` to `aten.slice`.

    Lowering is only possible when the index argument is a constant.
    The operators below are not converted yet because the index in their
    arguments is not a constant tensor.

    TODO Support below with const indices
    - torch.ops.aten.embedding.default
    - torch.ops.aten.index.Tensor
    """
    lowering_passes = [
        LowerSelectCopyToSlice(),
        LowerIndexSelectToSlice(),
    ]
    return lowering_passes
54
+
55
+
56
@trace_const_diff_on_pass
class LowerSelectCopyToSlice(PassBase):
    """
    [before]
      input
        |
      select (tensor, dim, *index)
        |
      output

    [after]

      input
        |
      slice (input=tensor, dim=dim, start=index, end=index+1, step=1)
        |
      reshape (input=slice_copy, size=select_shape)
        |
      output

    A negative `dim` or `index` is normalized to its non-negative equivalent
    before the slice bounds are computed.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Replace every select node with a one-element slice followed by a
        reshape to the select node's shape.

        Returns:
            A PassResult whose flag is True when at least one node was replaced.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not is_target_node(node, ops.aten.select):
                continue

            args = SelectCopyIntArgs(*node.args, **node.kwargs)
            input_node = args.input
            dim = args.dim
            index = args.index

            input_shape = extract_shape(input_node)
            if dim < 0:
                dim = dim % len(input_shape)
            # A negative index counts from the end. Without normalization,
            # index == -1 gives the range [-1, 0), which is an empty slice.
            if index < 0:
                index = index + input_shape[dim]

            start = index
            end = index + 1
            step = 1
            slice_copy_args = (input_node, dim, start, end, step)

            with graph.inserting_after(node):
                # slice
                slice_node = create_node(
                    graph,
                    torch.ops.aten.slice.Tensor,
                    args=slice_copy_args,
                    origin=node,
                )
            node_shape = extract_shape(node)
            with graph.inserting_after(slice_node):
                # reshape: drop the size-1 axis that slice keeps
                reshape_args = (slice_node, list(node_shape))
                reshape_node = create_node(
                    graph,
                    torch.ops.aten.reshape.default,
                    args=reshape_args,
                )
            node.replace_all_uses_with(reshape_node, propagate_meta=True)

            modified = True
            logger.debug(
                f"{node.name} is replaced with {slice_node.name} and {reshape_node.name} operators"
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
133
+
134
+
135
@trace_const_diff_on_pass
class LowerIndexSelectToSlice(PassBase):
    """

    [before]
      input
        |
      index_select.default (tensor, dim, *index)
        |
      output

    [after]

      input
        |
      slice (input=tensor, dim=dim, start=index, end=index+1, step=1)
        |
      reshape (input=slice_copy, size=select_shape)
        |
      output
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Replace index_select nodes whose index is a constant single-value
        tensor with a one-element slice followed by a reshape.

        Returns:
            A PassResult whose flag is True when at least one node was replaced.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False

        # (predicate, getter) pairs tried in order to resolve an index node
        # to the tensor it stands for.
        constant_lookups = (
            (is_lifted_tensor_constant, get_lifted_tensor_constant),
            (is_param, get_param),
            (is_buffer, get_buffer),
        )

        for node in graph.nodes:
            if not is_target_node(node, ops.aten.index_select):
                continue

            parsed = IndexSelectArgs(*node.args, **node.kwargs)
            source = parsed.input
            dim = parsed.dim
            index = parsed.index

            rank = len(extract_shape(source))
            if dim < 0:
                dim = dim % rank

            if isinstance(index, torch.fx.Node):
                for is_kind, fetch in constant_lookups:
                    if is_kind(exported_program, index):
                        index = fetch(exported_program, index)  # type: ignore[assignment]
                        break
                else:
                    # Not a lifted constant, parameter or buffer.
                    continue

            if not isinstance(index, torch.Tensor):
                continue

            if not is_single_value_tensor(index):
                # need to be lowered by LowerIndexSelect pass
                continue
            index_int = index.item()  # convert scalar tensor to int

            slice_copy_args = (source, dim, index_int, index_int + 1, 1)

            with graph.inserting_after(node):
                # slice
                slice_node = create_node(
                    graph,
                    torch.ops.aten.slice.Tensor,
                    args=slice_copy_args,
                    origin=node,
                )
            node_shape = extract_shape(node)
            with graph.inserting_after(slice_node):
                # reshape
                reshape_node = create_node(
                    graph,
                    torch.ops.aten.reshape.default,
                    args=(slice_node, list(node_shape)),
                )
            node.replace_all_uses_with(reshape_node, propagate_meta=True)

            modified = True
            logger.debug(
                f"{node.name} is replaced with {slice_node.name} and {reshape_node.name} operators"
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
@@ -0,0 +1,80 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from torch.export import ExportedProgram
16
+
17
+ from tico.passes import ops
18
+ from tico.utils import logging
19
+ from tico.utils.passes import PassBase, PassResult
20
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
21
+ from tico.utils.utils import is_target_node
22
+ from tico.utils.validate_args_kwargs import CatArgs
23
+
24
+
25
@trace_graph_diff_on_pass
class MergeConsecutiveCat(PassBase):
    """
    This pass merges consecutive `aten.cat` operators when they can be merged into single operator.

    A producer `cat` is folded into its consumer `cat` only when both
    concatenate along the same `dim` value.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Inline the inputs of producer `cat` nodes into each consumer `cat`.

        Returns:
            A PassResult whose flag is True when at least one `cat` was rewritten.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for cat in graph.nodes:
            if not is_target_node(cat, ops.aten.cat):
                continue

            cat_args = CatArgs(*cat.args, **cat.kwargs)  # type: ignore[arg-type]
            inputs = cat_args.tensors
            dim = cat_args.dim

            new_inputs = []
            for operand in inputs:
                mergeable = (
                    operand.op == "call_function" and operand.target in ops.aten.cat
                )
                if mergeable:
                    producer_args = CatArgs(*operand.args, **operand.kwargs)  # type: ignore[arg-type]
                    mergeable = producer_args.dim == dim

                if mergeable:
                    # Splice the producer's operands in place of the producer.
                    new_inputs.extend(producer_args.tensors)
                else:
                    new_inputs.append(operand)

            # Rewrite only when merging actually added operands.
            if len(new_inputs) > len(inputs):
                cat.args = (new_inputs, dim)

                modified = True
                logger.debug(
                    f"Consecutive cat nodes before {cat.name} are merged into {cat.name}"
                )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
tico/passes/ops.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+
17
+
18
+ """
19
+ This module contains Op lists used for finding the target Ops in passes.
20
+ The module is introduced to reduce duplicate codes.
21
+ It should be guaranteed that Ops in the same list have the same input/output signature.
22
+ """
23
+
24
+
25
+ class AtenOps:
26
+ def __init__(self):
27
+ # In alphabetical order
28
+ self.add = [torch.ops.aten.add.Tensor]
29
+ self.alias = [torch.ops.aten.alias.default, torch.ops.aten.alias_copy.default]
30
+ self.cat = [torch.ops.aten.cat.default]
31
+ self.clamp = [torch.ops.aten.clamp.default, torch.ops.aten.clamp.Tensor]
32
+ self.clone = [torch.ops.aten.clone.default]
33
+ self.conv2d = [
34
+ torch.ops.aten.conv2d.default,
35
+ torch.ops.aten.conv2d.padding,
36
+ ]
37
+ self.conv1d = [
38
+ torch.ops.aten.conv1d.default,
39
+ torch.ops.aten.conv1d.padding,
40
+ ]
41
+ self.detach = [
42
+ torch.ops.aten.detach_.default,
43
+ torch.ops.aten.detach.default,
44
+ ]
45
+ self.expand = [
46
+ torch.ops.aten.expand.default,
47
+ torch.ops.aten.expand_copy.default,
48
+ ]
49
+ self.index_select = [torch.ops.aten.index_select.default]
50
+ self.mean = [torch.ops.aten.mean.dim]
51
+ self.mul_scalar = [torch.ops.aten.mul.Scalar]
52
+ self.mul_tensor = [torch.ops.aten.mul.Tensor]
53
+ self.permute = [torch.ops.aten.permute.default]
54
+ self.reshape = [torch.ops.aten.reshape.default]
55
+ self.select = [torch.ops.aten.select_copy.int, torch.ops.aten.select.int]
56
+ self.slice = [torch.ops.aten.slice.Tensor, torch.ops.aten.slice_copy.Tensor]
57
+ self.softmax = [
58
+ torch.ops.aten._softmax.default,
59
+ torch.ops.aten._safe_softmax.default,
60
+ ]
61
+ self.squeeze = [torch.ops.aten.squeeze.dims, torch.ops.aten.squeeze_copy.dims]
62
+ self.to_copy = [
63
+ torch.ops.aten._to_copy.default,
64
+ torch.ops.aten.to.dtype,
65
+ torch.ops.aten.to.dtype_layout,
66
+ ]
67
+ self.unsqueeze = [
68
+ torch.ops.aten.unsqueeze.default,
69
+ torch.ops.aten.unsqueeze_copy.default,
70
+ ]
71
+ self.view = [
72
+ torch.ops.aten.view,
73
+ torch.ops.aten.view.default,
74
+ torch.ops.aten.view_copy.default,
75
+ ]
76
+
77
+
78
+ aten = AtenOps()
@@ -0,0 +1,84 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.passes import ops
23
+ from tico.utils import logging
24
+ from tico.utils.passes import PassBase, PassResult
25
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
26
+ from tico.utils.utils import is_target_node
27
+
28
+
29
@trace_graph_diff_on_pass
class RemoveNop(PassBase):
    """
    Remove no-op nodes by forwarding their single input to their users.
    """

    # Operators whose output is replaced by their first argument.
    target_ops = [
        torch.ops.prims.view_of.default,
        *ops.aten.alias,
        *ops.aten.clone,
        *ops.aten.detach,
        torch.ops.aten.lift_fresh_copy.default,
    ]

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Bypass every target node and clean up the graph afterwards.

        Returns:
            A PassResult whose flag is True when at least one node was bypassed.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not is_target_node(node, RemoveNop.target_ops):
                continue

            # TODO Consider memory format
            is_clone = node.target in ops.aten.clone
            if is_clone and "memory_format" in node.kwargs:
                # Converting non-contiguous layout to contiguous only updates
                # strides of tensor. This is not visible on circle, so we can
                # safely ignore this operation.
                ignorable_formats = (torch.preserve_format, torch.contiguous_format)
                if node.kwargs["memory_format"] not in ignorable_formats:
                    continue

            assert len(node.args) == 1

            src = node.args[0]
            assert isinstance(src, torch.fx.Node)

            with graph.inserting_after(node):
                node.replace_all_uses_with(src, propagate_meta=False)

            modified = True
            logger.debug(f"{node.name} is replaced with {src}")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)