tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,70 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.passes import ops
23
+ from tico.serialize.circle_mapping import extract_shape
24
+ from tico.utils import logging
25
+ from tico.utils.passes import PassBase, PassResult
26
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
+
28
+
29
@trace_graph_diff_on_pass
class RemoveRedundantExpand(PassBase):
    """
    This pass removes redundant `aten.expand` operators where shapes of input and output are same.

    An entry of `-1` in the requested size means "keep the size of that dimension",
    so it is treated as matching the corresponding input dimension.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Replace every no-op `aten.expand` node with its input.

        Args:
            exported_program: Program whose graph module is rewritten in place.

        Returns:
            A `PassResult` built from a flag telling whether any node was replaced.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if node.op != "call_function":
                continue

            if node.target not in ops.aten.expand:
                continue

            assert len(node.args) == 2

            # Named `expand_input` so that the builtin `input` is not shadowed.
            expand_input, size = list(node.args)
            assert isinstance(expand_input, torch.fx.Node), type(expand_input)
            assert isinstance(size, list), type(size)

            input_shape = list(extract_shape(expand_input))
            # A longer size list prepends new dimensions, which is never a no-op.
            if len(size) != len(input_shape):
                continue
            # Each requested entry must either keep the dimension (-1) or equal it.
            if any(s != -1 and s != d for s, d in zip(size, input_shape)):
                continue

            node.replace_all_uses_with(expand_input, propagate_meta=False)

            modified = True
            logger.debug(f"{node.name} is replaced with {expand_input.name}")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
@@ -0,0 +1,102 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.passes import ops
23
+ from tico.serialize.circle_mapping import extract_shape, extract_stride
24
+ from tico.utils import logging
25
+ from tico.utils.passes import PassBase, PassResult
26
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
+
28
+
29
def passes():
    """
    Build the list of passes that strip redundant `aten.permute` operators.

    NOTE Both shape and stride of input/output should be same.
    """
    permute_passes = [RemoveRedundantPermutePattern1()]
    return permute_passes
38
+
39
+
40
@trace_graph_diff_on_pass
class RemoveRedundantPermutePattern1(PassBase):
    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Remove two chained `aten.permute` ops whose result has the same shape
        and stride as the tensor fed into the first one.

        [BEFORE]
          (AxBxC) - aten.permute - aten.permute - (AxBxC)
        [AFTER]
          (AxBxC)
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        fx_graph = gm.graph
        changed = False
        for second in fx_graph.nodes:
            # --- outer (second) permute: single-user permute call ---
            if second.op != "call_function":
                continue
            if second.target not in ops.aten.permute:
                continue
            if len(second.users) != 1:
                continue
            assert len(second.args) == 2
            first, second_dims = second.args
            assert isinstance(first, torch.fx.Node), type(first)
            assert isinstance(second_dims, list), type(second_dims)
            for axis in second_dims:
                assert isinstance(axis, int), type(axis)

            # --- inner (first) permute: must feed only the outer one ---
            if first.target not in ops.aten.permute:
                continue
            if len(first.users) != 1:
                continue
            assert len(first.args) == 2
            source, first_dims = first.args
            assert isinstance(source, torch.fx.Node), type(source)
            assert isinstance(first_dims, list), type(first_dims)
            for axis in first_dims:
                assert isinstance(axis, int), type(axis)

            # The pair is redundant only if the layout is fully restored:
            # compare shape first, then stride.
            if extract_shape(source) != extract_shape(second):
                continue
            if extract_stride(source) != extract_stride(second):
                continue

            second.replace_all_uses_with(source, propagate_meta=False)

            changed = True
            logger.debug(f"{first.name} and {second.name} are removed.")

        fx_graph.eliminate_dead_code()
        fx_graph.lint()
        gm.recompile()

        return PassResult(changed)
@@ -0,0 +1,431 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.passes import ops
23
+ from tico.serialize.circle_mapping import extract_shape
24
+ from tico.utils import logging
25
+ from tico.utils.passes import PassBase, PassResult
26
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
+ from tico.utils.utils import set_new_meta_val
28
+ from tico.utils.validate_args_kwargs import ReshapeArgs
29
+
30
+
31
def passes():
    """
    Build the list of passes that strip redundant `aten.reshape` operators.

    One instance of each pattern class is created, in pattern-number order.
    """
    pattern_classes = (
        RemoveRedundantReshapePattern1,
        RemoveRedundantReshapePattern2,
        RemoveRedundantReshapePattern3,
        RemoveRedundantReshapePattern4,
        RemoveRedundantReshapePattern5,
    )
    return [pattern_cls() for pattern_cls in pattern_classes]
42
+
43
+
44
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern1(PassBase):
    # Targets accepted as the `aten.mul` step of the pattern.
    mul_ops: List[torch._ops.OpOverload] = ops.aten.mul_scalar + ops.aten.mul_tensor

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Remove the reshapes that only add and then drop a leading unit dimension
        around a permute + mul chain.

        [BEFORE]
          `(AxBxC) - aten.reshape` - (1xAxBxC) - `aten.permute` - (1xAxCxB) - `aten.mul` - (1xAxCxB) - `aten.reshape - (AxCxB)`
        [AFTER]
          `(AxBxC) - `aten.permute` - (AxCxB) - `aten.mul` - (AxCxB)`

        Args:
            exported_program: Program whose graph module is rewritten in place.

        Returns:
            A `PassResult` built from a flag telling whether the graph was changed.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for reshape1 in graph.nodes:
            if reshape1.op != "call_function":
                continue

            ### first reshape
            if reshape1.target not in ops.aten.reshape:
                continue
            # Assumes that other nodes do not use ops in the pattern for simplicity.
            if len(reshape1.users) != 1:
                continue
            assert len(reshape1.args) == 2, len(reshape1.args)
            reshape1_input, _ = reshape1.args
            # `(AxBxC) - aten.reshape` - (1xAxBxC)
            if [1] + list(extract_shape(reshape1_input)) != list(
                extract_shape(reshape1)
            ):
                continue

            ### permute
            permute = next(iter(reshape1.users))
            if permute.target not in ops.aten.permute:
                continue
            if len(permute.users) != 1:
                continue
            assert len(permute.args) == 2, len(permute.args)
            _, permute_dims = permute.args
            # (1xAxBxC) - `aten.permute` - (1xAxCxB)
            if permute_dims != [0, 1, 3, 2]:
                continue

            ### mul
            mul = next(iter(permute.users))
            if mul.target not in RemoveRedundantReshapePattern1.mul_ops:
                continue
            if len(mul.users) != 1:
                continue

            ### second reshape
            reshape2 = next(iter(mul.users))
            if reshape2.target not in ops.aten.reshape:
                continue
            if len(reshape2.users) != 1:
                continue
            reshape2_input, _ = reshape2.args
            # (1xAxCxB) - `aten.reshape - (AxCxB)
            if list(extract_shape(reshape2_input)) != [1] + list(
                extract_shape(reshape2)
            ):
                continue

            # After the rewrite `mul` consumes a tensor of one rank lower. Another
            # tensor operand of higher rank would broadcast the result back to the
            # old rank, so dropping reshape2 would change the shape seen downstream.
            target_rank = len(extract_shape(reshape1_input))
            if any(
                isinstance(operand, torch.fx.Node)
                and operand is not permute
                and len(extract_shape(operand)) > target_rank
                for operand in mul.args
            ):
                continue

            ### remove redundant reshapes
            # update permute (remove reshape1)
            permute.args = (reshape1_input, [0, 2, 1])
            set_new_meta_val(permute)
            set_new_meta_val(mul)
            # remove reshape2
            reshape2.replace_all_uses_with(mul, propagate_meta=False)

            modified = True
            logger.debug(f"{reshape1.name} and {reshape2.name} are removed.")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
129
+
130
+
131
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern2(PassBase):
    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Drop the leading reshape that only prepends a unit dimension ahead of a
        permute + reshape chain.

        [BEFORE]
          `(AxBxC) - aten.reshape` - (1xAxBxC) - `aten.permute` - (Bx1xAxC) - `aten.reshape - (Bx(A*C))`
        [AFTER]
          `(AxBxC) - `aten.permute` - (BxAxC) - `aten.reshape` - (Bx(A*C))`
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        fx_graph = gm.graph
        changed = False
        for head in fx_graph.nodes:
            if head.op != "call_function":
                continue

            # --- leading reshape: (AxBxC) -> (1xAxBxC), single user ---
            if head.target not in ops.aten.reshape:
                continue
            if len(head.users) != 1:
                continue
            assert len(head.args) == 2, len(head.args)
            source, _ = head.args
            expected_head_shape = [1] + list(extract_shape(source))
            if expected_head_shape != list(extract_shape(head)):
                continue

            # --- permute: (1xAxBxC) -> (Bx1xAxC), single user ---
            transpose = next(iter(head.users))
            if transpose.target not in ops.aten.permute:
                continue
            if len(transpose.users) != 1:
                continue
            assert len(transpose.args) == 2, len(transpose.args)
            _, transpose_dims = transpose.args
            if transpose_dims != [2, 0, 1, 3]:
                continue

            # --- trailing reshape: (Bx1xAxC) -> (Bx(A*C)), single user ---
            tail = next(iter(transpose.users))
            if tail.target not in ops.aten.reshape:
                continue
            if len(tail.users) != 1:
                continue
            tail_input, tail_size = tail.args
            tail_input_shape = list(extract_shape(tail_input))
            assert len(tail_input_shape) == 4
            merged_shape = [
                tail_input_shape[0],
                (tail_input_shape[2] * tail_input_shape[3]),
            ]
            if list(extract_shape(tail)) != merged_shape:
                continue

            # Rewire the permute straight onto the 3D source and refresh its meta.
            transpose.args = (source, [1, 0, 2])
            set_new_meta_val(transpose)
            head.replace_all_uses_with(transpose, propagate_meta=False)
            # The trailing reshape keeps its size and consumes the updated permute.
            assert transpose == tail_input
            tail.args = (transpose, tail_size)

            changed = True
            logger.debug(f"{head.name} is removed.")

        fx_graph.eliminate_dead_code()
        fx_graph.lint()
        gm.recompile()

        return PassResult(changed)
210
+
211
+
212
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern3(PassBase):
    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Remove the reshapes wrapped around an add + last-axis softmax chain.

        [BEFORE]
          (AxBxC) - aten.reshape - (1xAxBxC) - aten.add - (1xAxBxC) - aten.softmax - (1xAxBxC) - aten.reshape - (AxBxC)
                     (reshape_2)                (add)                   (softmax)                  (reshape_1)
          (AxBxC) - aten.reshape - (1xAxBxC) /
                     (reshape_3)
        [AFTER]
          (AxBxC) - aten.add - (AxBxC) - aten.softmax - (AxBxC)
          (AxBxC) /  (add)                (softmax)

        The rewrite is applied only when the reshapes leave the last axis
        untouched; otherwise the softmax would normalize over different elements.

        Args:
            exported_program: Program whose graph module is rewritten in place.

        Returns:
            A `PassResult` built from a flag telling whether the graph was changed.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for reshape_1 in graph.nodes:
            assert isinstance(reshape_1, torch.fx.Node), type(reshape_1)
            # reshape_1
            if reshape_1.op != "call_function":
                continue
            if reshape_1.target not in ops.aten.reshape:
                continue
            assert len(reshape_1.args) == 2, len(reshape_1.args)
            softmax, _ = reshape_1.args

            # softmax
            assert isinstance(softmax, torch.fx.Node), type(softmax)
            if softmax.op != "call_function":
                continue
            if softmax.target not in ops.aten.softmax:
                continue
            assert len(softmax.args) == 3, len(softmax.args)
            add, softmax_dim, softmax_half_to_float = softmax.args
            assert isinstance(add, torch.fx.Node), type(add)
            assert isinstance(softmax_dim, int), type(softmax_dim)
            assert isinstance(softmax_half_to_float, bool), type(softmax_half_to_float)
            softmax_shape = extract_shape(softmax)
            # TODO support other dimension
            if softmax_dim != -1 and softmax_dim != len(softmax_shape) - 1:
                continue

            # add
            if add.target not in ops.aten.add:
                continue
            assert len(add.args) == 2, len(add.args)
            reshape_2, reshape_3 = add.args
            # A non-Node operand (e.g. a Python scalar) cannot be a reshape, so the
            # pattern simply does not match; skip instead of aborting the pass.
            if not isinstance(reshape_2, torch.fx.Node):
                continue
            if not isinstance(reshape_3, torch.fx.Node):
                continue

            # reshape_2
            if reshape_2.op != "call_function":
                continue
            if reshape_2.target not in ops.aten.reshape:
                continue
            assert len(reshape_2.args) == 2, len(reshape_2.args)
            reshape_2_input, _ = reshape_2.args
            assert isinstance(reshape_2_input, torch.fx.Node), type(reshape_2_input)
            # reshape_3
            if reshape_3.op != "call_function":
                continue
            if reshape_3.target not in ops.aten.reshape:
                continue
            assert len(reshape_3.args) == 2, len(reshape_3.args)
            reshape_3_input, _ = reshape_3.args
            assert isinstance(reshape_3_input, torch.fx.Node), type(reshape_3_input)

            # Check condition
            reshape_2_input_shape = extract_shape(reshape_2_input)
            reshape_3_input_shape = extract_shape(reshape_3_input)
            if reshape_2_input_shape != reshape_3_input_shape:
                continue
            reshape_1_shape = extract_shape(reshape_1)
            if reshape_2_input_shape != reshape_1_shape:
                continue
            # Softmax runs over the last axis. It commutes with the surrounding
            # reshapes only when that axis keeps its size, i.e. the same groups of
            # elements are normalized before and after the rewrite.
            if len(softmax_shape) == 0 or len(reshape_2_input_shape) == 0:
                continue
            if softmax_shape[-1] != reshape_2_input_shape[-1]:
                continue
            # Assume `aten.add` and `aten.softmax` have only one user.
            if len(add.users) != 1:
                continue
            if len(softmax.users) != 1:
                continue

            # Update add
            add.args = (reshape_2_input, reshape_3_input)
            set_new_meta_val(add)
            # Update softmax
            # A non-negative dim refers to the old rank, so re-point it at the
            # last axis of the new input; `-1` is already rank-independent.
            if softmax_dim == len(softmax_shape) - 1:
                updated_dim = len(reshape_2_input_shape) - 1
                softmax.args = (add, updated_dim, softmax_half_to_float)
            # The softmax output shape changes in both cases, so always refresh it.
            set_new_meta_val(softmax)

            reshape_1.replace_all_uses_with(softmax, propagate_meta=False)
            modified = True
            logger.debug(
                f"{reshape_2.name}, {reshape_3.name} and {reshape_1.name} are removed."
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
318
+
319
+
320
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern4(PassBase):
    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Fuse two consecutive `aten.reshape` ops into a single one.

        NOTE: Below graph is just an example. This pattern matches not only for the 3D tensors.
        What this pattern aims to remove is that the consecutive `aten.reshape` ops.
        [BEFORE]
          (AxBxC) - aten.reshape - (AxB'xC') - aten.reshape - (A'xB''xC')
        [AFTER]
          (AxBxC) - aten.reshape - (A'xB''xC')
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        fx_graph = gm.graph
        changed = False
        for head in fx_graph.nodes:
            # --- first reshape ---
            if head.op != "call_function":
                continue
            if head.target not in ops.aten.reshape:
                continue
            assert len(head.args) == 2, len(head.args)

            source, head_size = list(head.args)
            assert isinstance(source, torch.fx.Node), type(source)
            assert isinstance(head_size, list), type(head_size)
            for extent in head_size:
                assert isinstance(extent, int), type(extent)

            # The intermediate result must feed exactly one node.
            if len(head.users) != 1:
                continue

            # --- second reshape ---
            tail = next(iter(head.users))
            if tail.op != "call_function":
                continue
            if tail.target not in ops.aten.reshape:
                continue
            assert len(tail.args) == 2, len(tail.args)

            tail_input, tail_size = list(tail.args)
            assert isinstance(tail_input, torch.fx.Node), type(tail_input)
            assert isinstance(tail_size, list), type(tail_size)
            for extent in tail_size:
                assert isinstance(extent, int), type(extent)

            # One reshape from the original source straight to the final size.
            with fx_graph.inserting_before(head):
                fused_reshape = fx_graph.call_function(
                    head.target, (source, tail_size)
                )

            tail.replace_all_uses_with(fused_reshape, propagate_meta=True)

            changed = True
            logger.debug(
                f"{head.name} and {tail.name} are fused to {fused_reshape.name}"
            )

        fx_graph.eliminate_dead_code()
        fx_graph.lint()
        gm.recompile()

        return PassResult(changed)
387
+
388
+
389
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern5(PassBase):
    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Remove `aten.reshape` ops whose output shape equals their input shape.

        [BEFORE]
          (AxBxC) - aten.reshape - (AxBxC)
        [AFTER]
          (AxBxC)

        Args:
            exported_program: Program whose graph module is rewritten in place.

        Returns:
            A `PassResult` built from a flag telling whether any node was replaced.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False

        for node in graph.nodes:
            if node.op != "call_function":
                continue

            if node.target not in ops.aten.reshape:
                continue

            args = ReshapeArgs(*node.args)
            input_shape = list(extract_shape(args.input))
            # Compare the node's actual output shape rather than the requested
            # size, so that a size written with `-1` is also recognized as a no-op.
            output_shape = list(extract_shape(node))

            if output_shape != input_shape:
                continue

            node.replace_all_uses_with(args.input, propagate_meta=False)

            modified = True
            logger.debug(f"{node.name} is replaced with {args.input}")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
@@ -0,0 +1,64 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from torch.export import ExportedProgram
16
+
17
+ from tico.passes import ops
18
+ from tico.serialize.circle_mapping import extract_shape
19
+ from tico.utils import logging
20
+ from tico.utils.passes import PassBase, PassResult
21
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
22
+ from tico.utils.validate_args_kwargs import SliceArgs
23
+
24
+
25
@trace_graph_diff_on_pass
class RemoveRedundantSlice(PassBase):
    """
    Pass that strips slice operators whose output shape equals their input shape.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Replace every shape-preserving slice node with the tensor it slices.
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        fx_graph = gm.graph
        changed = False
        for candidate in fx_graph.nodes:
            if candidate.op != "call_function":
                continue

            if candidate.target not in ops.aten.slice:
                continue

            slice_args = SliceArgs(*candidate.args, **candidate.kwargs)
            sliced = slice_args.input

            # Equal shapes are taken to mean the slice selected everything.
            if extract_shape(sliced) != extract_shape(candidate):
                continue

            candidate.replace_all_uses_with(sliced, propagate_meta=False)

            changed = True
            logger.debug(f"{candidate.name} is replaced with {sliced.name}")

        fx_graph.eliminate_dead_code()
        fx_graph.lint()
        gm.recompile()

        return PassResult(changed)