tico-0.1.0.dev250411-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/passes/lower_to_resize_nearest_neighbor.py ADDED
@@ -0,0 +1,249 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ from typing import Optional
+
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.serialize.circle_mapping import extract_shape
+ from tico.utils import logging
+ from tico.utils.errors import NotYetSupportedError
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.validate_args_kwargs import IndexArgs, UpsampleNearest2DVecArgs
+
+
+ @trace_graph_diff_on_pass
+ class LowerToResizeNearestNeighbor(PassBase):
+     """
+     This pass lowers `aten.index` and `aten.upsample_nearest2d.vec` to `circle_custom.resize_nearest_neighbor` when possible.
+
+     Until torch 2.7, `torch.nn.functional.interpolate` is converted to the `aten.index` op.
+
+     [EXAMPLE]
+     class InterpolateDouble(torch.nn.Module):
+         def __init__(self):
+             super().__init__()
+
+         def forward(self, x):
+             return torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+
+         def get_example_inputs(self):
+             return (torch.randn(1, 2, 3, 4),)
+
+     [EXPORTED GRAPH]
+     [constants]
+     _prop_tensor_constant0 = tensor([0, 0, 1, 1, 2, 2, 3, 3])
+     _prop_tensor_constant1 = tensor([[0], [0], [1], [1], [2], [2]])
+
+     [graph]
+     %_prop_tensor_constant0 : [num_users=1] = placeholder[target=_prop_tensor_constant0]
+     %_prop_tensor_constant1 : [num_users=1] = placeholder[target=_prop_tensor_constant1]
+     %x : [num_users=1] = placeholder[target=x]
+     %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.float32})
+     %index : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%_to_copy, [None, None, %_prop_tensor_constant1, %_prop_tensor_constant0]), kwargs = {})
+     %_to_copy_3 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%index,), kwargs = {dtype: torch.float32})
+     return (_to_copy_3,)
+
+     [BEFORE PASS]
+     input - aten.index - output
+
+     [AFTER PASS]
+     input - aten.permute(NCHW_to_NHWC) - circle_custom.resize_nearest_neighbor - aten.permute(NHWC_to_NCHW) - output
+
+     Since torch 2.8, `torch.nn.functional.interpolate` is converted to the `aten.upsample_nearest2d.vec` op.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def convert_index_to_resize_nearest_neighbor(
+         self, exported_program, node
+     ) -> Optional[torch.fx.Node]:
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+
+         args = IndexArgs(*node.args, **node.kwargs)
+         input_tensor = args.input
+         indices = args.indices
+
+         # Only 4-D tensors are supported.
+         if len(indices) != 4:
+             return None
+         # indices = [None, None, H index, W index]
+         N, C, H, W = indices
+         if N != None or C != None:
+             return None
+         if not isinstance(H, torch.fx.Node):
+             return None
+         if not isinstance(W, torch.fx.Node):
+             return None
+         constants_dict = exported_program.constants
+         if (H.name not in constants_dict) or (W.name not in constants_dict):
+             return None
+         H_index, W_index = constants_dict[H.name], constants_dict[W.name]
+         input_tensor_shape = extract_shape(input_tensor)
+         input_tensor_H, input_tensor_W = (
+             input_tensor_shape[2],
+             input_tensor_shape[3],
+         )
+         if H_index.size()[0] % input_tensor_H != 0:
+             return None
+         scale_factor = int(H_index.size()[0] / input_tensor_H)
+         # H and W should be resized with the same ratio.
+         if scale_factor != W_index.size()[0] / input_tensor_W:
+             return None
+         expected_H_index = []
+         expected_W_index = []
+         # See the `_prop_tensor_constant1` constant in the example above.
+         for i in range(input_tensor_H):
+             expected_H_index += [[i]] * scale_factor
+         # See the `_prop_tensor_constant0` constant in the example above.
+         for i in range(input_tensor_W):
+             expected_W_index += [i] * scale_factor
+         if not torch.all(
+             torch.eq(H_index, torch.tensor(expected_H_index))
+         ) or not torch.all(torch.eq(W_index, torch.tensor(expected_W_index))):
+             return None
+         expected_shape = [
+             input_tensor_shape[0],
+             input_tensor_shape[1],
+             len(expected_H_index),
+             len(expected_W_index),
+         ]
+         assert expected_shape == list(extract_shape(node))
+
+         with graph.inserting_before(node):
+             nchw_to_nhwc = graph.call_function(
+                 torch.ops.aten.permute.default, args=(input_tensor, [0, 2, 3, 1])
+             )
+             resize_nearest_neighbor = graph.call_function(
+                 torch.ops.circle_custom.resize_nearest_neighbor,
+                 args=(nchw_to_nhwc, [len(expected_H_index), len(expected_W_index)]),
+             )
+             nhwc_to_nchw = graph.call_function(
+                 torch.ops.aten.permute.default,
+                 args=(resize_nearest_neighbor, [0, 3, 1, 2]),
+             )
+             # Meta is not set manually; `propagate_meta=True` copies the replaced node's meta.
+             node.replace_all_uses_with(nhwc_to_nchw, propagate_meta=True)
+
+         return resize_nearest_neighbor
+
+     def convert_upsample_nearest2d_to_resize_nearest_neighbor(
+         self, exported_program, node
+     ) -> Optional[torch.fx.Node]:
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+
+         args = UpsampleNearest2DVecArgs(*node.args, **node.kwargs)
+         input_tensor = args.input
+         output_size = args.output_size
+         scale_factors = args.scale_factors
+
+         input_tensor_shape = extract_shape(input_tensor)
+         input_tensor_H, input_tensor_W = (
+             input_tensor_shape[2],
+             input_tensor_shape[3],
+         )
+
+         if output_size is not None:
+             raise NotYetSupportedError("output_size is not supported yet")
+
+         if scale_factors is None:
+             raise NotYetSupportedError("scale_factors is None")
+         # TODO Support the output_size case. Currently only the scale_factors case is supported.
+
+         assert (
+             isinstance(scale_factors[0], float)
+             and isinstance(scale_factors[1], float)
+             and scale_factors[0] > 0
+             and scale_factors[1] > 0
+         )
+
+         def close_enough(x, y, epsilon=1e-10):
+             return abs(x - y) < epsilon
+
+         expected_H = int(input_tensor_H * scale_factors[0])
+         if not close_enough(expected_H, input_tensor_H * scale_factors[0]):
+             raise NotYetSupportedError(
+                 f"Cannot support input_tensor_H ({input_tensor_H}) with scaling factor ({scale_factors[0]})"
+             )
+
+         expected_W = int(input_tensor_W * scale_factors[1])
+         if not close_enough(expected_W, input_tensor_W * scale_factors[1]):
+             raise NotYetSupportedError(
+                 f"Cannot support input_tensor_W ({input_tensor_W}) with scaling factor ({scale_factors[1]})"
+             )
+
+         with graph.inserting_before(node):
+             nchw_to_nhwc = graph.call_function(
+                 torch.ops.aten.permute.default, args=(input_tensor, [0, 2, 3, 1])
+             )
+             resize_nearest_neighbor = graph.call_function(
+                 torch.ops.circle_custom.resize_nearest_neighbor,
+                 args=(nchw_to_nhwc, [expected_H, expected_W]),
+             )
+             nhwc_to_nchw = graph.call_function(
+                 torch.ops.aten.permute.default,
+                 args=(resize_nearest_neighbor, [0, 3, 1, 2]),
+             )
+             # Meta is not set manually; `propagate_meta=True` copies the replaced node's meta.
+             node.replace_all_uses_with(nhwc_to_nchw, propagate_meta=True)
+         return resize_nearest_neighbor
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         modified = False
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         for node in graph.nodes:
+             if not node.op == "call_function":
+                 continue
+
+             if node.target not in [
+                 torch.ops.aten.index.Tensor,
+                 torch.ops.aten.upsample_nearest2d.vec,
+             ]:
+                 continue
+
+             resize_nearest_neighbor = None
+             if node.target == torch.ops.aten.index.Tensor:
+                 resize_nearest_neighbor = self.convert_index_to_resize_nearest_neighbor(
+                     exported_program, node
+                 )
+             elif node.target == torch.ops.aten.upsample_nearest2d.vec:
+                 resize_nearest_neighbor = (
+                     self.convert_upsample_nearest2d_to_resize_nearest_neighbor(
+                         exported_program, node
+                     )
+                 )
+
+             if resize_nearest_neighbor:
+                 modified = True
+                 logger.debug(
+                     f"{node.name} is replaced with {resize_nearest_neighbor.name} operator"
+                 )
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         return PassResult(modified)
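
The index pattern this pass matches can be checked directly in PyTorch. A standalone sketch (not part of the package) verifying that advanced indexing with the repeated row/column constants from the docstring is exactly a nearest-neighbor upscale:

    import torch

    x = torch.randn(1, 2, 3, 4)
    scale = 2
    # Same layout as _prop_tensor_constant1 / _prop_tensor_constant0 above.
    h_idx = torch.tensor([[i] for i in range(x.shape[2]) for _ in range(scale)])  # [[0], [0], [1], [1], [2], [2]]
    w_idx = torch.tensor([i for i in range(x.shape[3]) for _ in range(scale)])    # [0, 0, 1, 1, 2, 2, 3, 3]

    by_index = x[:, :, h_idx, w_idx]  # what the matched aten.index.Tensor computes
    by_interp = torch.nn.functional.interpolate(x, scale_factor=float(scale), mode="nearest")
    assert torch.equal(by_index, by_interp)
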
tico/passes/lower_to_slice.py ADDED
@@ -0,0 +1,112 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.passes import ops
+
+ from tico.serialize.circle_graph import extract_shape
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_const_diff_on_pass
+ from tico.utils.validate_args_kwargs import SelectCopyIntArgs
+
+
+ @trace_const_diff_on_pass
+ class LowerToSlice(PassBase):
+     """
+     This pass lowers `aten.select.int` / `aten.select_copy.int` to `aten.slice.Tensor`.
+     Only the case where the index is given in args as a constant is supported.
+     Since the index in their args is not a constant tensor, the ops below cannot be converted yet.
+     - torch.ops.aten.index_select.default
+     - torch.ops.aten.embedding.default
+     - torch.ops.aten.index.Tensor
+
+     [before]
+     input (tensor, dim, *index)
+       |
+     select
+       |
+     output
+
+     [after]
+     input (tensor, dim, *index)
+       |
+     slice (input=tensor, dim=dim, start=index, end=index+1, step=1)
+       |
+     reshape (input=slice_copy, size=select_shape)
+       |
+     output
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         modified = False
+         for node in graph.nodes:
+             if not node.op == "call_function":
+                 continue
+
+             if not node.target in ops.aten.select:
+                 continue
+
+             args = SelectCopyIntArgs(*node.args, **node.kwargs)
+             input = args.input
+             dim = args.dim
+             index = args.index
+
+             input_shape = extract_shape(input)
+             if dim < 0:
+                 dim = dim % len(input_shape)
+
+             start = index
+             end = index + 1
+             step = 1
+             slice_copy_args = (input, dim, start, end, step)
+
+             with graph.inserting_after(node):
+                 # slice
+                 slice_node = graph.call_function(
+                     torch.ops.aten.slice.Tensor, args=slice_copy_args
+                 )
+             node_shape = extract_shape(node)
+             with graph.inserting_after(slice_node):
+                 # reshape
+                 reshape_args = (slice_node, list(node_shape))
+                 reshape_node = graph.call_function(
+                     torch.ops.aten.reshape.default, args=reshape_args
+                 )
+                 node.replace_all_uses_with(reshape_node, propagate_meta=False)
+
+             modified = True
+             logger.debug(
+                 f"{node.name} is replaced with {slice_node.name} and {reshape_node.name} operators"
+             )
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         return PassResult(modified)
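
The rewrite performed by LowerToSlice can be reproduced with plain tensor ops. A standalone sketch (not part of the package) showing that select(dim, index) equals a width-one slice followed by a reshape that drops the selected dimension:

    import torch

    x = torch.randn(2, 5, 3)
    dim, index = 1, 4

    selected = torch.select(x, dim, index)                             # shape (2, 3)
    sliced = torch.ops.aten.slice.Tensor(x, dim, index, index + 1, 1)  # shape (2, 1, 3)
    reshaped = sliced.reshape(selected.shape)                          # drop the selected dim
    assert torch.equal(selected, reshaped)
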
tico/passes/merge_consecutive_cat.py ADDED
@@ -0,0 +1,82 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from torch.export import ExportedProgram
+
+ from tico.passes import ops
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.validate_args_kwargs import CatArgs
+
+
+ @trace_graph_diff_on_pass
+ class MergeConsecutiveCat(PassBase):
+     """
+     This pass merges consecutive `aten.cat` operators into a single operator when possible.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         modified = False
+         for cat in graph.nodes:
+             if not cat.op == "call_function":
+                 continue
+
+             if not cat.target in ops.aten.cat:
+                 continue
+
+             args = CatArgs(*cat.args, **cat.kwargs)  # type: ignore[arg-type]
+             inputs = args.tensors
+             dim = args.dim
+
+             new_inputs = []
+             for prev_cat in inputs:
+                 new_inputs.append(prev_cat)
+                 if not prev_cat.op == "call_function":
+                     continue
+
+                 if not prev_cat.target in ops.aten.cat:
+                     continue
+
+                 prev_args = CatArgs(*prev_cat.args, **prev_cat.kwargs)  # type: ignore[arg-type]
+                 prev_inputs = prev_args.tensors
+                 prev_dim = prev_args.dim
+
+                 if not prev_dim == dim:
+                     continue
+
+                 new_inputs.pop()
+                 for prev_input in prev_inputs:
+                     new_inputs.append(prev_input)
+
+             if len(new_inputs) > len(inputs):
+                 cat.args = (new_inputs, dim)
+
+                 modified = True
+                 logger.debug(
+                     f"Consecutive cat nodes before {cat.name} are merged into {cat.name}"
+                 )
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         return PassResult(modified)
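
The identity MergeConsecutiveCat exploits is easy to check standalone; a minimal sketch (not part of the package): a cat whose input is another cat along the same dim can absorb that cat's inputs directly.

    import torch

    a, b, c = torch.randn(2, 3), torch.randn(2, 3), torch.randn(2, 3)

    nested = torch.cat([torch.cat([a, b], dim=0), c], dim=0)  # cat feeding another cat on the same dim
    merged = torch.cat([a, b, c], dim=0)                      # what the pass rewrites it to
    assert torch.equal(nested, merged)
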
tico/passes/ops.py ADDED
@@ -0,0 +1,75 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+
+ """
+ This module contains op lists used for finding target ops in passes.
+ It was introduced to reduce code duplication.
+ It should be guaranteed that ops in the same list have the same input/output signature.
+ """
+
+
+ class AtenOps:
+     def __init__(self):
+         # In alphabetical order
+         self.add = [torch.ops.aten.add.Tensor]
+         self.alias = [torch.ops.aten.alias.default, torch.ops.aten.alias_copy.default]
+         self.cat = [torch.ops.aten.cat.default]
+         self.clamp = [torch.ops.aten.clamp.default, torch.ops.aten.clamp.Tensor]
+         self.clone = [torch.ops.aten.clone.default]
+         self.conv2d = [
+             torch.ops.aten.conv2d.default,
+             torch.ops.aten.conv2d.padding,
+         ]
+         self.conv1d = [
+             torch.ops.aten.conv1d.default,
+             torch.ops.aten.conv1d.padding,
+         ]
+         self.detach = [
+             torch.ops.aten.detach_.default,
+             torch.ops.aten.detach.default,
+         ]
+         self.expand = [
+             torch.ops.aten.expand.default,
+             torch.ops.aten.expand_copy.default,
+         ]
+         self.index_select = [torch.ops.aten.index_select.default]
+         self.mean = [torch.ops.aten.mean.dim]
+         self.mul_scalar = [torch.ops.aten.mul.Scalar]
+         self.mul_tensor = [torch.ops.aten.mul.Tensor]
+         self.permute = [torch.ops.aten.permute.default]
+         self.reshape = [torch.ops.aten.reshape.default]
+         self.select = [torch.ops.aten.select_copy.int, torch.ops.aten.select.int]
+         self.slice = [torch.ops.aten.slice.Tensor, torch.ops.aten.slice_copy.Tensor]
+         self.softmax = [torch.ops.aten._softmax.default]
+         self.squeeze = [torch.ops.aten.squeeze.dims, torch.ops.aten.squeeze_copy.dims]
+         self.to_copy = [
+             torch.ops.aten._to_copy.default,
+             torch.ops.aten.to.dtype,
+             torch.ops.aten.to.dtype_layout,
+         ]
+         self.unsqueeze = [
+             torch.ops.aten.unsqueeze.default,
+             torch.ops.aten.unsqueeze_copy.default,
+         ]
+         self.view = [
+             torch.ops.aten.view,
+             torch.ops.aten.view.default,
+             torch.ops.aten.view_copy.default,
+         ]
+
+
+ aten = AtenOps()
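
A minimal sketch (not part of the package, assuming the wheel is installed) of how passes consume these lists: a single membership test on `node.target` covers every registered overload of an op.

    import torch

    from tico.passes import ops

    print(torch.ops.aten.select.int in ops.aten.select)       # True
    print(torch.ops.aten.select_copy.int in ops.aten.select)  # True
    print(torch.ops.aten.slice.Tensor in ops.aten.slice)      # True
    print(torch.ops.aten.slice.Tensor in ops.aten.select)     # False
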
tico/passes/remove_nop.py ADDED
@@ -0,0 +1,85 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.passes import ops
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+
+
+ @trace_graph_diff_on_pass
+ class RemoveNop(PassBase):
+     """
+     Remove no-op nodes by replacing their uses with their inputs.
+     """
+
+     target_ops = (
+         [
+             torch.ops.prims.view_of.default,
+         ]
+         + ops.aten.alias
+         + ops.aten.clone
+         + ops.aten.detach
+         + [torch.ops.aten.lift_fresh_copy.default]
+     )
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         modified = False
+         for node in graph.nodes:
+             if not node.op == "call_function":
+                 continue
+
+             if not node.target in RemoveNop.target_ops:
+                 continue
+             # TODO Consider memory format
+             if node.target in ops.aten.clone and "memory_format" in node.kwargs:
+                 if node.kwargs["memory_format"] not in [
+                     torch.preserve_format,
+                     # Converting a non-contiguous layout to contiguous only updates
+                     # the tensor's strides. This is not visible in circle, so we can
+                     # safely ignore this operation.
+                     torch.contiguous_format,
+                 ]:
+                     continue
+
+             assert len(node.args) == 1
+
+             src = node.args[0]
+             assert isinstance(src, torch.fx.Node)
+
+             with graph.inserting_after(node):
+                 node.replace_all_uses_with(src, propagate_meta=False)
+
+             modified = True
+             logger.debug(f"{node.name} is replaced with {src}")
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         return PassResult(modified)
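
Why bypassing these nodes is safe can be checked on a small FX graph. A standalone sketch (not part of the package) that reroutes clone/detach users to their single input, the same way RemoveNop does, and confirms the result is unchanged:

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, x):
            return x.clone().detach() + 1

    gm = torch.fx.symbolic_trace(M())
    x = torch.randn(2, 3)
    before = gm(x)

    # Bypass the no-op nodes by rerouting their users to their single input.
    for node in gm.graph.nodes:
        if node.op == "call_method" and node.target in ("clone", "detach"):
            node.replace_all_uses_with(node.args[0])
    gm.graph.eliminate_dead_code()
    gm.recompile()

    assert torch.equal(before, gm(x))  # same values without the no-op nodes
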
tico/passes/remove_redundant_assert_nodes.py ADDED
@@ -0,0 +1,50 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+
+
+ assert_node_targets = [
+     torch.ops.aten._assert_tensor_metadata.default,
+ ]
+
+
+ @trace_graph_diff_on_pass
+ class RemoveRedundantAssertionNodes(PassBase):
+     """
+     This pass removes redundant assertion nodes:
+     - `aten._assert_tensor_metadata.default`
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         modified = False
+         for node in graph.nodes:
+             if node.op == "call_function" and node.target in assert_node_targets:
+                 graph.erase_node(node)
+                 modified = True
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         return PassResult(modified)
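
A minimal sketch (not part of the package) of driving this pass by hand on an exported program, assuming the wheel is installed and that the installed torch version emits `_assert_tensor_metadata` nodes for dtype casts; the usual entry point is tico's own conversion pipeline rather than calling passes directly.

    import torch
    from torch.export import export

    from tico.passes.remove_redundant_assert_nodes import RemoveRedundantAssertionNodes

    class Cast(torch.nn.Module):
        def forward(self, x):
            return x.to(torch.float64) + 1.0

    ep = export(Cast(), (torch.randn(2, 2),))
    RemoveRedundantAssertionNodes().call(ep)  # erases any _assert_tensor_metadata nodes
    print(ep.graph_module.graph)              # assertion nodes, if any were present, are gone
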