tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,51 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from torch.export import ExportedProgram
17
+
18
+ from tico.utils.passes import PassBase, PassResult
19
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
20
+ from tico.utils.utils import is_target_node
21
+
22
+
23
# ATen assertion ops whose nodes `RemoveRedundantAssertionNodes` erases from
# the graph.
assert_node_targets = [torch.ops.aten._assert_tensor_metadata.default]
26
+
27
+
28
@trace_graph_diff_on_pass
class RemoveRedundantAssertionNodes(PassBase):
    """
    This removes redundant assertion nodes.
      - `aten._assert_tensor_metadata.default`

    The removed targets are those listed in the module-level
    `assert_node_targets`.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Erase every node whose target is listed in `assert_node_targets`.

        Returns `PassResult(modified)`, where `modified` is True when at least
        one node was erased.
        """
        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        # Iterate over a snapshot so erasing nodes never interferes with the
        # traversal of the live node list.
        for node in list(graph.nodes):
            if is_target_node(node, assert_node_targets):
                graph.erase_node(node)
                modified = True

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
@@ -0,0 +1,66 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.passes import ops
23
+ from tico.serialize.circle_mapping import extract_shape
24
+ from tico.utils import logging
25
+ from tico.utils.passes import PassBase, PassResult
26
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
+ from tico.utils.utils import is_target_node
28
+ from tico.utils.validate_args_kwargs import ExpandArgs
29
+
30
+
31
@trace_graph_diff_on_pass
class RemoveRedundantExpand(PassBase):
    """
    This pass removes redundant `aten.expand` operators where shapes of input and output are same.

    A `-1` entry in the requested size means "keep this input dimension", so
    it is resolved against the input shape before comparing. For example,
    `x.expand(-1, 4)` on an (N, 4) tensor is removed as well.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Replace every no-op `aten.expand` node with its input.

        Returns `PassResult(modified)`, where `modified` is True when at least
        one node was replaced.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not is_target_node(node, ops.aten.expand):
                continue

            args = ExpandArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
            # Normalise to a list so a tuple/torch.Size size compares correctly.
            expand_input, size = args.input, list(args.size)

            input_shape = list(extract_shape(expand_input))
            # Expanding to a different rank adds dimensions: never a no-op.
            if len(size) != len(input_shape):
                continue
            # Resolve `-1` ("keep this dim") before comparing shapes.
            target_shape = [
                in_dim if out_dim == -1 else out_dim
                for in_dim, out_dim in zip(input_shape, size)
            ]
            if target_shape != input_shape:
                continue

            node.replace_all_uses_with(expand_input, propagate_meta=False)

            modified = True
            logger.debug(f"{node.name} is replaced with {expand_input.name}")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
@@ -0,0 +1,122 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.passes import ops
23
+ from tico.serialize.circle_mapping import extract_shape
24
+ from tico.utils import logging
25
+ from tico.utils.graph import create_node
26
+ from tico.utils.passes import PassBase, PassResult
27
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
+ from tico.utils.utils import is_target_node
29
+ from tico.utils.validate_args_kwargs import PermuteArgs
30
+
31
+
32
+ def _compose_permutation(dims1: list[int], dims2: list[int]):
33
+ """
34
+ Compose two permutation vectors.
35
+
36
+ Given y = x.permute(dims1) and z = y.permute(dims2),
37
+ the overall permutation p = dims2 ∘ dims1 is
38
+
39
+ p[i] = dims1[dims2[i]]
40
+ """
41
+ assert len(dims1) == len(
42
+ dims2
43
+ ), f"len(dims1): {len(dims1)}, len(dims2): {len(dims2)}"
44
+ return [dims1[i] for i in dims2]
45
+
46
+
47
def passes():
    """
    Build the list of passes that eliminate redundant `aten.permute` ops.

    NOTE Both shape and stride of input/output should be same.
    """
    permute_passes = [RemoveRedundantPermutePattern1()]
    return permute_passes
56
+
57
+
58
@trace_graph_diff_on_pass
class RemoveRedundantPermutePattern1(PassBase):
    """
    Merge two back-to-back `aten.permute` nodes into one, or drop both when
    their composition is the identity permutation.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        [BEFORE]
          (AxBxC) - aten.permute_1 - aten.permute_2 - (OUT_SHAPE)
        [AFTER]
          if OUT_SHAPE == (AxBxC):
            (AxBxC)
          else:
            (AxBxC) - aten.permute (fused dims) - (OUT_SHAPE)

        Both permutes must have exactly one user each.
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        graph = gm.graph
        modified = False
        for outer in graph.nodes:
            # `outer` is the second permute of the pair.
            if not is_target_node(outer, ops.aten.permute):
                continue
            if len(outer.users) != 1:
                continue
            outer_args = PermuteArgs(*outer.args, **outer.kwargs)  # type: ignore[arg-type]
            inner = outer_args.input

            # `inner` is the first permute of the pair.
            if not is_target_node(inner, ops.aten.permute):
                continue
            if len(inner.users) != 1:
                continue
            inner_args = PermuteArgs(*inner.args, **inner.kwargs)  # type: ignore[arg-type]
            source = inner_args.input

            fused_dims = _compose_permutation(inner_args.dims, outer_args.dims)
            cancels_out = fused_dims == list(range(len(fused_dims)))

            if not cancels_out:
                with graph.inserting_after(outer):
                    merged = create_node(
                        graph,
                        torch.ops.aten.permute.default,
                        args=(source, fused_dims),
                    )
                outer.replace_all_uses_with(merged, propagate_meta=True)
                logger.debug(f"{inner.name} and {outer.name} are fused.")
            else:
                # Permutes that cancel out must leave the shape unchanged.
                source_shape = extract_shape(source)
                outer_shape = extract_shape(outer)
                assert source_shape == outer_shape
                outer.replace_all_uses_with(source, propagate_meta=False)
                logger.debug(f"{inner.name} and {outer.name} are removed.")
            modified = True

        graph.eliminate_dead_code()
        graph.lint()
        gm.recompile()

        return PassResult(modified)
@@ -0,0 +1,436 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.passes import ops
23
+ from tico.serialize.circle_mapping import extract_shape
24
+ from tico.utils import logging
25
+ from tico.utils.graph import create_node
26
+ from tico.utils.passes import PassBase, PassResult
27
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
+ from tico.utils.utils import broadcastable, is_target_node, set_new_meta_val
29
+ from tico.utils.validate_args_kwargs import (
30
+ AddTensorArgs,
31
+ PermuteArgs,
32
+ ReshapeArgs,
33
+ SafeSoftmaxArgs,
34
+ SoftmaxArgs,
35
+ )
36
+
37
+
38
def passes():
    """
    Build the list of passes that eliminate redundant `aten.reshape` ops.

    Pattern passes are instantiated in order, from Pattern1 to Pattern5.
    """
    pattern_classes = (
        RemoveRedundantReshapePattern1,
        RemoveRedundantReshapePattern2,
        RemoveRedundantReshapePattern3,
        RemoveRedundantReshapePattern4,
        RemoveRedundantReshapePattern5,
    )
    return [pattern_cls() for pattern_cls in pattern_classes]
49
+
50
+
51
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern1(PassBase):
    """
    Remove a reshape pair that adds and then drops a leading unit axis around
    an `aten.permute` + `aten.mul` chain.
    """

    mul_ops: List[torch._ops.OpOverload] = ops.aten.mul_scalar + ops.aten.mul_tensor

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        [BEFORE]
        `(AxBxC) - aten.reshape` - (1xAxBxC) - `aten.permute` - (1xAxCxB) - `aten.mul` - (1xAxCxB) - `aten.reshape - (AxCxB)`
        [AFTER]
        `(AxBxC) - `aten.permute` - (AxCxB) - `aten.mul` - (AxCxB)`

        For a tensor-tensor `aten.mul`, every other tensor operand must have
        rank <= 3 (the rank of AxBxC). A higher-rank operand would keep the
        `mul` output at rank >= 4 after the leading unit axis is dropped, so
        removing the second reshape would change the output shape.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for reshape1 in graph.nodes:
            ### first reshape
            if not is_target_node(reshape1, ops.aten.reshape):
                continue

            # Assumes that other nodes do not use ops in the pattern for simplicity.
            if len(reshape1.users) != 1:
                continue
            reshape1_args = ReshapeArgs(*reshape1.args, **reshape1.kwargs)  # type: ignore[arg-type]
            reshape1_input = reshape1_args.input
            reshape1_input_shape = list(extract_shape(reshape1_input))
            # `(AxBxC) - aten.reshape` - (1xAxBxC)
            if [1] + reshape1_input_shape != list(extract_shape(reshape1)):
                continue

            ### permute
            permute = next(iter(reshape1.users))
            if not is_target_node(permute, ops.aten.permute):
                continue
            if len(permute.users) != 1:
                continue
            permute_args = PermuteArgs(*permute.args, **permute.kwargs)  # type: ignore[arg-type]
            # (1xAxBxC) - `aten.permute` - (1xAxCxB)
            if permute_args.dims != [0, 1, 3, 2]:
                continue

            ### mul
            mul = next(iter(permute.users))
            if not is_target_node(mul, RemoveRedundantReshapePattern1.mul_ops):
                continue
            if len(mul.users) != 1:
                continue
            if mul.target in ops.aten.mul_tensor:
                # Operands other than `permute` are not rewritten. If any has a
                # higher rank than the 3D input, broadcasting would keep the
                # rank-4+ output and dropping reshape2 would be wrong.
                other_operands = [
                    arg
                    for arg in (*mul.args, *mul.kwargs.values())
                    if isinstance(arg, torch.fx.Node) and arg is not permute
                ]
                if any(
                    len(extract_shape(other)) > len(reshape1_input_shape)
                    for other in other_operands
                ):
                    continue

            ### second reshape
            reshape2 = next(iter(mul.users))
            if not is_target_node(reshape2, ops.aten.reshape):
                continue
            if len(reshape2.users) != 1:
                continue
            reshape2_args = ReshapeArgs(*reshape2.args, **reshape2.kwargs)  # type: ignore[arg-type]
            reshape2_input = reshape2_args.input
            # (1xAxCxB) - `aten.reshape - (AxCxB)
            if list(extract_shape(reshape2_input)) != [1] + list(
                extract_shape(reshape2)
            ):
                continue

            ### remove redundant reshapes
            # update permute (remove reshape1)
            permute.args = (reshape1_input, [0, 2, 1])
            set_new_meta_val(permute)
            set_new_meta_val(mul)
            # remove reshape2
            reshape2.replace_all_uses_with(mul, propagate_meta=False)

            modified = True
            logger.debug(f"{reshape1.name} and {reshape2.name} are removed.")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
135
+
136
+
137
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern2(PassBase):
    """
    Remove a reshape that adds a leading unit axis in front of a permute whose
    result is flattened by a second reshape.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        [BEFORE]
        `(AxBxC) - aten.reshape` - (1xAxBxC) - `aten.permute` - (Bx1xAxC) - `aten.reshape - (Bx(A*C))`
        [AFTER]
        `(AxBxC) - `aten.permute` - (BxAxC) - `aten.reshape` - (Bx(A*C))`
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for reshape1 in graph.nodes:
            ### first reshape
            if not is_target_node(reshape1, ops.aten.reshape):
                continue
            if len(reshape1.users) != 1:
                continue
            reshape1_args = ReshapeArgs(*reshape1.args, **reshape1.kwargs)  # type: ignore[arg-type]
            reshape1_input = reshape1_args.input
            # `(AxBxC) - aten.reshape` - (1xAxBxC)
            if [1] + list(extract_shape(reshape1_input)) != list(
                extract_shape(reshape1)
            ):
                continue

            ### permute
            permute = next(iter(reshape1.users))
            if not is_target_node(permute, ops.aten.permute):
                continue
            if len(permute.users) != 1:
                continue
            permute_args = PermuteArgs(*permute.args, **permute.kwargs)  # type: ignore[arg-type]
            # (1xAxBxC) - `aten.permute` - (Bx1xAxC)
            if permute_args.dims != [2, 0, 1, 3]:
                continue

            ### second reshape
            reshape2 = next(iter(permute.users))
            if not is_target_node(reshape2, ops.aten.reshape):
                continue
            if len(reshape2.users) != 1:
                continue
            reshape2_args = ReshapeArgs(*reshape2.args, **reshape2.kwargs)  # type: ignore[arg-type]
            reshape2_input = reshape2_args.input
            # (Bx1xAxC) - `aten.reshape - (Bx(A*C))
            reshape2_input_shape = list(extract_shape(reshape2_input))
            assert len(reshape2_input_shape) == 4
            if list(extract_shape(reshape2)) != [
                reshape2_input_shape[0],
                (reshape2_input_shape[2] * reshape2_input_shape[3]),
            ]:
                continue

            ### remove redundant reshape
            # Re-point the permute at the 3D input and drop the unit axis from
            # its dims. Reassigning `permute.args` detaches `reshape1`, which
            # is then collected by dead-code elimination below. `reshape2`
            # keeps consuming `permute` with its unchanged target size, so it
            # needs no update.
            assert permute == reshape2_input
            permute.args = (reshape1_input, [1, 0, 2])
            set_new_meta_val(permute)

            modified = True
            logger.debug(f"{reshape1.name} is removed.")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
214
+
215
+
216
def _broadcast_shape(lhs, rhs) -> list:
    """
    Return the broadcast result of two shapes already known to be broadcastable.

    Shapes are right-aligned; a size-1 dimension stretches to match the other.
    """
    rank = max(len(lhs), len(rhs))
    padded_lhs = [1] * (rank - len(lhs)) + list(lhs)
    padded_rhs = [1] * (rank - len(rhs)) + list(rhs)
    return [b if a == 1 else a for a, b in zip(padded_lhs, padded_rhs)]


@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern3(PassBase):
    """
    Remove the reshapes around an `aten.add` + softmax chain so the chain runs
    on the original (pre-reshape) tensors.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        [BEFORE]
        (AxBxC) - aten.reshape - (1xAxBxC) - aten.add - (1xAxBxC) - aten.softmax - (1xAxBxC) - aten.reshape - (AxBxC)
                  (reshape_2)               (add)                  (softmax)                  (reshape_1)
        (AxBxC) - aten.reshape - (1xAxBxC) /
                  (reshape_3)
        [AFTER]
        (AxBxC) - aten.add - (AxBxC) - aten.softmax - (AxBxC)
        (AxBxC) /  (add)               (softmax)

        The rewrite is applied only when broadcasting the two pre-reshape
        operands yields exactly `reshape_1`'s output shape. Otherwise,
        removing the reshapes would change the result shape. For example,
        (H,1,K) + (H,K) broadcasts to (H,H,K), not (H,1,K).
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for reshape_1 in graph.nodes:
            # reshape_1
            if not is_target_node(reshape_1, ops.aten.reshape):
                continue
            reshape_1_args = ReshapeArgs(*reshape_1.args, **reshape_1.kwargs)  # type: ignore[arg-type]
            softmax = reshape_1_args.input

            # softmax
            softmax_args = None
            if not is_target_node(softmax, ops.aten.softmax):
                continue
            if softmax.target == torch.ops.aten._softmax.default:
                softmax_args = SoftmaxArgs(*softmax.args, **softmax.kwargs)  # type: ignore[arg-type, assignment]
            elif softmax.target == torch.ops.aten._safe_softmax.default:
                softmax_args = SafeSoftmaxArgs(*softmax.args, **softmax.kwargs)  # type: ignore[arg-type, assignment]
            else:
                raise RuntimeError("Invalid input")
            assert softmax_args is not None
            add, softmax_dim = (
                softmax_args.input,
                softmax_args.dim,
            )
            softmax_shape = extract_shape(softmax)
            # TODO support other dimension
            if softmax_dim != -1 and softmax_dim != len(softmax_shape) - 1:
                continue

            # add
            if not add.target in ops.aten.add:
                continue
            add_args = AddTensorArgs(*add.args, **add.kwargs)  # type: ignore[arg-type]
            reshape_2, reshape_3 = add_args.input, add_args.other
            # An operand that is not a graph node (e.g. a constant) cannot be
            # a reshape, so the pattern does not match. Skip it instead of
            # failing an assertion.
            if not isinstance(reshape_2, torch.fx.Node):
                continue
            if not isinstance(reshape_3, torch.fx.Node):
                continue

            # reshape_2
            if not reshape_2.op == "call_function":
                continue
            if not reshape_2.target in ops.aten.reshape:
                continue
            reshape_2_args = ReshapeArgs(*reshape_2.args, **reshape_2.kwargs)  # type: ignore[arg-type]
            reshape_2_input = reshape_2_args.input
            assert isinstance(reshape_2_input, torch.fx.Node), type(reshape_2_input)
            # reshape_3
            if not reshape_3.op == "call_function":
                continue
            if not reshape_3.target in ops.aten.reshape:
                continue
            reshape_3_args = ReshapeArgs(*reshape_3.args, **reshape_3.kwargs)  # type: ignore[arg-type]
            reshape_3_input = reshape_3_args.input
            assert isinstance(reshape_3_input, torch.fx.Node), type(reshape_3_input)

            # Check condition
            reshape_2_input_shape = extract_shape(reshape_2_input)
            reshape_3_input_shape = extract_shape(reshape_3_input)
            if not broadcastable(reshape_2_input_shape, reshape_3_input_shape):
                continue
            reshape_1_shape = extract_shape(reshape_1)
            if (
                reshape_2_input_shape != reshape_1_shape
                and reshape_3_input_shape != reshape_1_shape
            ):
                continue
            # The rewritten add broadcasts the pre-reshape operands directly.
            # Its result must match reshape_1's output shape exactly.
            if _broadcast_shape(reshape_2_input_shape, reshape_3_input_shape) != list(
                reshape_1_shape
            ):
                continue
            # Make sure the softmax axis length is unchanged.
            if softmax_shape[-1] != reshape_1_shape[-1]:
                continue
            # Assume `aten.add` and `aten.softmax` have only one user.
            if len(add.users) != 1:
                continue
            if len(softmax.users) != 1:
                continue

            # Update add
            add.args = (reshape_2_input, reshape_3_input)
            set_new_meta_val(add)
            # Update softmax
            if softmax_dim == len(softmax_shape) - 1:
                softmax.update_arg(1, -1)  # (index, last_dim)
            set_new_meta_val(softmax)

            reshape_1.replace_all_uses_with(softmax, propagate_meta=False)
            modified = True
            logger.debug(
                f"{reshape_2.name}, {reshape_3.name} and {reshape_1.name} are removed."
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
328
+
329
+
330
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern4(PassBase):
    """
    Collapse two consecutive `aten.reshape` nodes into a single reshape.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        NOTE: Below graph is just an example. This pattern matches not only for the 3D tensors.
        What this pattern aims to remove is that the consecutive `aten.reshape` ops.
        [BEFORE]
        (AxBxC) - aten.reshape - (AxB'xC') - aten.reshape - (A'xB''xC')
        [AFTER]
        (AxBxC) - aten.reshape - (A'xB''xC')

        Both reshapes must carry static (all-int, list) target sizes; this is
        asserted. The first reshape must have exactly one user.
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        graph = gm.graph
        modified = False
        for first in graph.nodes:
            if not is_target_node(first, ops.aten.reshape):
                continue

            first_args = ReshapeArgs(*first.args, **first.kwargs)  # type: ignore[arg-type]
            first_input = first_args.input
            first_size = first_args.shape
            assert isinstance(first_input, torch.fx.Node), type(first_input)
            assert isinstance(first_size, list), type(first_size)
            for dim in first_size:
                assert isinstance(dim, int), type(dim)

            if len(first.users) != 1:
                continue

            second = next(iter(first.users))
            if not is_target_node(second, ops.aten.reshape):
                continue

            second_args = ReshapeArgs(*second.args, **second.kwargs)  # type: ignore[arg-type]
            second_input = second_args.input
            second_size = second_args.shape
            assert isinstance(second_input, torch.fx.Node), type(second_input)
            assert isinstance(second_size, list), type(second_size)
            for dim in second_size:
                assert isinstance(dim, int), type(dim)

            # A single reshape straight to the final size replaces the pair.
            with graph.inserting_before(first):
                merged = create_node(
                    graph,
                    first.target,
                    (first_input, second_size),
                )
            second.replace_all_uses_with(merged, propagate_meta=True)

            modified = True
            logger.debug(
                f"{first.name} and {second.name} are fused to {merged.name}"
            )

        graph.eliminate_dead_code()
        graph.lint()
        gm.recompile()

        return PassResult(modified)
395
+
396
+
397
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern5(PassBase):
    """
    Remove `aten.reshape` nodes whose target shape equals their input shape.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        [BEFORE]
        (AxBxC) - aten.reshape - (AxBxC)
        [AFTER]
        (AxBxC)

        A single `-1` in the target shape is resolved against the input shape
        before comparing. For example, reshape(x, [-1, C]) on an (N, C)
        tensor is removed as well.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False

        for node in graph.nodes:
            if not is_target_node(node, ops.aten.reshape):
                continue

            args = ReshapeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
            # Normalise to a list so a tuple/torch.Size shape compares correctly.
            output_shape = list(args.shape)
            input_shape = list(extract_shape(args.input))

            # A rank-changing reshape is never a no-op. aten.reshape allows
            # at most one inferred (-1) dimension.
            if len(output_shape) != len(input_shape) or output_shape.count(-1) > 1:
                continue
            # With every other dim equal to the input's, the inferred dim is
            # the input dim at the same position.
            resolved_shape = [
                in_dim if out_dim == -1 else out_dim
                for in_dim, out_dim in zip(input_shape, output_shape)
            ]
            if resolved_shape != input_shape:
                continue

            node.replace_all_uses_with(args.input, propagate_meta=False)

            modified = True
            logger.debug(f"{node.name} is replaced with {args.input}")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)