tico 0.1.0.dev250609__py3-none-any.whl → 0.1.0.dev250611__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. tico/__init__.py +1 -1
  2. tico/passes/cast_aten_where_arg_type.py +65 -62
  3. tico/passes/cast_mixed_type_args.py +2 -5
  4. tico/passes/convert_conv1d_to_conv2d.py +3 -4
  5. tico/passes/convert_repeat_to_expand_copy.py +5 -9
  6. tico/passes/decompose_addmm.py +41 -48
  7. tico/passes/decompose_batch_norm.py +97 -99
  8. tico/passes/decompose_fake_quantize.py +4 -6
  9. tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -4
  10. tico/passes/decompose_group_norm.py +8 -7
  11. tico/passes/decompose_grouped_conv2d.py +2 -3
  12. tico/passes/decompose_slice_scatter.py +2 -4
  13. tico/passes/extract_dtype_kwargs.py +2 -1
  14. tico/passes/fuse_leading_unsqueeze_reshape.py +107 -0
  15. tico/passes/fuse_redundant_reshape_to_mean.py +3 -5
  16. tico/passes/legalize_causal_mask_value.py +2 -7
  17. tico/passes/legalize_predefined_layout_operators.py +2 -3
  18. tico/passes/lower_pow2_to_mul.py +5 -7
  19. tico/passes/lower_to_resize_nearest_neighbor.py +6 -10
  20. tico/passes/lower_to_slice.py +3 -9
  21. tico/passes/merge_consecutive_cat.py +2 -4
  22. tico/passes/remove_nop.py +2 -3
  23. tico/passes/remove_redundant_assert_nodes.py +2 -1
  24. tico/passes/remove_redundant_expand.py +5 -9
  25. tico/passes/remove_redundant_permute.py +6 -5
  26. tico/passes/remove_redundant_reshape.py +17 -34
  27. tico/passes/remove_redundant_slice.py +2 -4
  28. tico/passes/remove_redundant_to_copy.py +2 -4
  29. tico/passes/segment_index_select.py +2 -4
  30. tico/serialize/operators/op_where.py +2 -2
  31. tico/utils/convert.py +2 -0
  32. tico/utils/utils.py +26 -0
  33. {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/METADATA +1 -1
  34. {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/RECORD +38 -37
  35. {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/LICENSE +0 -0
  36. {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/WHEEL +0 -0
  37. {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/entry_points.txt +0 -0
  38. {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/top_level.txt +0 -0
@@ -25,7 +25,7 @@ from tico.serialize.circle_mapping import extract_shape
25
25
  from tico.utils import logging
26
26
  from tico.utils.passes import PassBase, PassResult
27
27
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
- from tico.utils.utils import enforce_type
28
+ from tico.utils.utils import enforce_type, is_target_node
29
29
 
30
30
 
31
31
  @trace_graph_diff_on_pass
@@ -82,9 +82,7 @@ class DecomposeSliceScatter(PassBase):
82
82
  modified = False
83
83
 
84
84
  for node in graph.nodes:
85
- if node.op != "call_function":
86
- continue
87
- if node.target != torch.ops.aten.slice_scatter.default:
85
+ if not is_target_node(node, torch.ops.aten.slice_scatter.default):
88
86
  continue
89
87
 
90
88
  @enforce_type
@@ -23,6 +23,7 @@ from torch.utils import _pytree as pytree
23
23
  from tico.utils import logging
24
24
  from tico.utils.passes import PassBase, PassResult
25
25
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
26
+ from tico.utils.utils import is_target_node
26
27
 
27
28
 
28
29
  def _extract_to_output(node: torch.fx.Node, graph: torch.fx.Graph) -> bool:
@@ -107,7 +108,7 @@ class ExtractDtypeKwargsPass(PassBase):
107
108
  graph: torch.fx.Graph = graph_module.graph
108
109
  modified = False
109
110
  for node in graph.nodes:
110
- if not node.op == "call_function" or node.target not in self.target_ops:
111
+ if not is_target_node(node, list(self.target_ops.keys())):
111
112
  continue
112
113
  if "dtype" not in node.kwargs:
113
114
  continue
@@ -0,0 +1,107 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Sequence
16
+
17
+ import torch
18
+ from torch.export import ExportedProgram
19
+
20
+ from tico.passes import ops
21
+ from tico.serialize.circle_mapping import extract_shape
22
+ from tico.utils import logging
23
+ from tico.utils.passes import PassBase, PassResult
24
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
25
+ from tico.utils.utils import is_target_node
26
+ from tico.utils.validate_args_kwargs import PermuteArgs, ReshapeArgs
27
+
28
+
29
+ def _is_leading_unsqueeze(target: Sequence[int], permuted: Sequence[int]) -> bool:
30
+ """
31
+ True if `target` == [1]*k + permuted, k>=1.
32
+ """
33
+ k = len(target) - len(permuted)
34
+ return (
35
+ k > 0 and all(d == 1 for d in target[:k]) and list(target[k:]) == list(permuted)
36
+ )
37
+
38
+
39
@trace_graph_diff_on_pass
class FuseLeadingUnsqueezeReshape(PassBase):
    """
    Fuse reshape → permute → reshape where the second reshape only
    prepends one-sized dims (unsqueeze) to the permuted tensor.

    [BEFORE]
        x - aten.reshape(s1) - aten.permute(p) - aten.reshape([1]*k + p(s1))
    [AFTER]
        x - aten.reshape([1]*k + s1) - aten.permute(list(range(k)) + [d+k for d in p])

    Negative entries of `p` are normalized to non-negative axes before they
    are shifted by `k`, so the fused permute addresses the same axes as the
    original one.
    """

    def call(self, ep: ExportedProgram) -> PassResult:
        """
        Rewrite every matching reshape → permute → reshape chain in `ep`.

        Returns a `PassResult` built from a flag that is True when at least
        one chain was rewritten.
        """
        logger = logging.getLogger(__name__)

        gm = ep.graph_module
        graph = gm.graph
        modified = False
        for reshape_back in graph.nodes:
            # Match the chain backwards, starting from the trailing reshape.
            if not is_target_node(reshape_back, ops.aten.reshape):
                continue
            reshape_back_args = ReshapeArgs(*reshape_back.args, **reshape_back.kwargs)  # type: ignore[arg-type]
            permute = reshape_back_args.input

            if not is_target_node(permute, ops.aten.permute):
                continue
            permute_args = PermuteArgs(*permute.args, **permute.kwargs)  # type: ignore[arg-type]
            reshape_front, permute_dims = permute_args.input, permute_args.dims

            if not is_target_node(reshape_front, ops.aten.reshape):
                continue
            reshape_front_args = ReshapeArgs(*reshape_front.args, **reshape_front.kwargs)  # type: ignore[arg-type]
            reshape_front_input, reshape_front_size = (
                reshape_front_args.input,
                reshape_front_args.size,
            )

            # ---- condition: only leading unsqueeze ------------------------
            back_shape = extract_shape(reshape_back)
            permute_shape = extract_shape(permute)

            if not _is_leading_unsqueeze(back_shape, permute_shape):
                continue

            # ---- create new reshape & new permute -------------------------
            k = len(back_shape) - len(permute_shape)
            rank = len(permute_dims)
            with graph.inserting_before(permute):
                new_shape = [1] * k + list(reshape_front_size)
                r_new = graph.call_function(
                    torch.ops.aten.reshape.default,
                    args=(reshape_front_input, new_shape),
                )
                # Normalize negative axes first: shifting e.g. -1 by k would
                # otherwise point at a different (possibly leading) axis.
                new_p_dims = list(range(k)) + [
                    (d % rank) + k for d in permute_dims
                ]  # shift by k
                p_new = graph.call_function(
                    torch.ops.aten.permute.default, args=(r_new, new_p_dims)
                )
                # NOTE(review): `r_new` gets no meta here; only `p_new` receives
                # meta via `propagate_meta` below. Confirm that later passes do
                # not need `r_new.meta`.

            reshape_back.replace_all_uses_with(p_new, propagate_meta=True)
            modified = True
            logger.debug(f"{reshape_back.name} is fused to {r_new.name}")

        if modified:
            # The replaced nodes are now unused (unless they have other users).
            graph.eliminate_dead_code()
            graph.lint()
            gm.recompile()

        return PassResult(modified)
@@ -25,6 +25,7 @@ from tico.serialize.circle_mapping import extract_shape
25
25
  from tico.utils import logging
26
26
  from tico.utils.passes import PassBase, PassResult
27
27
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
+ from tico.utils.utils import is_target_node
28
29
 
29
30
 
30
31
  @trace_graph_diff_on_pass
@@ -45,10 +46,7 @@ class FuseRedundantReshapeToMean(PassBase):
45
46
  graph = graph_module.graph
46
47
  modified = False
47
48
  for node in graph.nodes:
48
- if not node.op == "call_function":
49
- continue
50
-
51
- if node.target != torch.ops.aten.mean.dim:
49
+ if not is_target_node(node, torch.ops.aten.mean.dim):
52
50
  continue
53
51
 
54
52
  # If mean is being used in other nodes, do not fuse it.
@@ -56,7 +54,7 @@ class FuseRedundantReshapeToMean(PassBase):
56
54
  continue
57
55
 
58
56
  user_node = next(iter(node.users))
59
- if user_node.target not in ops.aten.reshape:
57
+ if not is_target_node(user_node, ops.aten.reshape):
60
58
  continue
61
59
 
62
60
  mean_args, mean_kwargs = pytree.tree_map_only(
@@ -20,10 +20,10 @@ import torch
20
20
  from torch.export import ExportedProgram
21
21
 
22
22
  from tico.passes import ops
23
-
24
23
  from tico.utils import logging
25
24
  from tico.utils.passes import PassBase, PassResult
26
25
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
26
+ from tico.utils.utils import is_target_node
27
27
  from tico.utils.validate_args_kwargs import AddTensorArgs
28
28
 
29
29
 
@@ -53,14 +53,9 @@ class LegalizeCausalMaskValue(PassBase):
53
53
  graph = graph_module.graph
54
54
  modified = False
55
55
  for node in graph.nodes:
56
- if not node.op == "call_function":
57
- continue
58
-
59
- if not node.target in ops.aten.add:
56
+ if not is_target_node(node, ops.aten.add):
60
57
  continue
61
58
 
62
- assert len(node.args) == 2
63
-
64
59
  args = AddTensorArgs(*node.args, **node.kwargs)
65
60
  input = args.input
66
61
  other = args.other
@@ -25,6 +25,7 @@ from tico.utils import logging
25
25
  from tico.utils.errors import NotYetSupportedError
26
26
  from tico.utils.passes import PassBase, PassResult
27
27
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
+ from tico.utils.utils import is_target_node
28
29
  from tico.utils.validate_args_kwargs import (
29
30
  AvgPool2dArgs,
30
31
  Conv2DArgs,
@@ -363,11 +364,9 @@ class LegalizePreDefinedLayoutOperators(PassBase):
363
364
  graph = graph_module.graph
364
365
  modified = False
365
366
  for node in graph.nodes:
366
- if not node.op == "call_function":
367
+ if not is_target_node(node, list(target_to_legalize_func.keys())):
367
368
  continue
368
369
 
369
- if node.target not in target_to_legalize_func:
370
- continue
371
370
  modified |= target_to_legalize_func[node.target](exported_program, node)
372
371
 
373
372
  graph.eliminate_dead_code()
@@ -22,6 +22,8 @@ from torch.export import ExportedProgram
22
22
  from tico.utils import logging
23
23
  from tico.utils.passes import PassBase, PassResult
24
24
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
25
+ from tico.utils.utils import is_target_node
26
+ from tico.utils.validate_args_kwargs import PowTensorScalarArgs
25
27
 
26
28
 
27
29
  @trace_graph_diff_on_pass
@@ -42,15 +44,11 @@ class LowerPow2ToMul(PassBase):
42
44
  graph = graph_module.graph
43
45
  modified = False
44
46
  for node in graph.nodes:
45
- if not node.op == "call_function":
47
+ if not is_target_node(node, torch.ops.aten.pow.Tensor_Scalar):
46
48
  continue
47
49
 
48
- if node.target != torch.ops.aten.pow.Tensor_Scalar:
49
- continue
50
-
51
- assert len(node.args) == 2, len(node.args)
52
- in_, exp = node.args
53
- assert isinstance(in_, torch.fx.Node), type(in_)
50
+ args = PowTensorScalarArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
51
+ in_, exp = args.input, args.exponent
54
52
 
55
53
  if exp != 2:
56
54
  continue
@@ -12,12 +12,10 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import TYPE_CHECKING
15
+ from typing import Optional, TYPE_CHECKING
16
16
 
17
17
  if TYPE_CHECKING:
18
18
  import torch.fx
19
- from typing import Optional
20
-
21
19
  import torch
22
20
  from torch.export import ExportedProgram
23
21
 
@@ -26,6 +24,7 @@ from tico.utils import logging
26
24
  from tico.utils.errors import NotYetSupportedError
27
25
  from tico.utils.passes import PassBase, PassResult
28
26
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
+ from tico.utils.utils import is_target_node
29
28
  from tico.utils.validate_args_kwargs import IndexArgs, UpsampleNearest2DVecArgs
30
29
 
31
30
 
@@ -194,13 +193,10 @@ class LowerToResizeNearestNeighbor(PassBase):
194
193
  graph_module = exported_program.graph_module
195
194
  graph = graph_module.graph
196
195
  for node in graph.nodes:
197
- if not node.op == "call_function":
198
- continue
199
-
200
- if node.target not in [
201
- torch.ops.aten.index.Tensor,
202
- torch.ops.aten.upsample_nearest2d.vec,
203
- ]:
196
+ if not is_target_node(
197
+ node,
198
+ [torch.ops.aten.index.Tensor, torch.ops.aten.upsample_nearest2d.vec],
199
+ ):
204
200
  continue
205
201
 
206
202
  resize_nearest_neighbor = None
@@ -28,13 +28,12 @@ from torch._export.utils import (
28
28
  from torch.export import ExportedProgram
29
29
 
30
30
  from tico.passes import ops
31
-
32
31
  from tico.serialize.circle_graph import extract_shape
33
32
  from tico.utils import logging
34
-
35
33
  from tico.utils.graph import is_single_value_tensor
36
34
  from tico.utils.passes import PassBase, PassResult
37
35
  from tico.utils.trace_decorators import trace_const_diff_on_pass
36
+ from tico.utils.utils import is_target_node
38
37
  from tico.utils.validate_args_kwargs import IndexSelectArgs, SelectCopyIntArgs
39
38
 
40
39
 
@@ -85,10 +84,7 @@ class LowerSelectCopyToSlice(PassBase):
85
84
  graph = graph_module.graph
86
85
  modified = False
87
86
  for node in graph.nodes:
88
- if not node.op == "call_function":
89
- continue
90
-
91
- if not node.target in ops.aten.select:
87
+ if not is_target_node(node, ops.aten.select):
92
88
  continue
93
89
 
94
90
  args = SelectCopyIntArgs(*node.args, **node.kwargs)
@@ -163,11 +159,9 @@ class LowerIndexSelectToSlice(PassBase):
163
159
  graph = graph_module.graph
164
160
  modified = False
165
161
  for node in graph.nodes:
166
- if not node.op == "call_function":
162
+ if not is_target_node(node, ops.aten.index_select):
167
163
  continue
168
164
 
169
- if not node.target in ops.aten.index_select:
170
- continue
171
165
  args = IndexSelectArgs(*node.args, **node.kwargs)
172
166
  input = args.input
173
167
  dim = args.dim
@@ -18,6 +18,7 @@ from tico.passes import ops
18
18
  from tico.utils import logging
19
19
  from tico.utils.passes import PassBase, PassResult
20
20
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
21
+ from tico.utils.utils import is_target_node
21
22
  from tico.utils.validate_args_kwargs import CatArgs
22
23
 
23
24
 
@@ -37,10 +38,7 @@ class MergeConsecutiveCat(PassBase):
37
38
  graph = graph_module.graph
38
39
  modified = False
39
40
  for cat in graph.nodes:
40
- if not cat.op == "call_function":
41
- continue
42
-
43
- if not cat.target in ops.aten.cat:
41
+ if not is_target_node(cat, ops.aten.cat):
44
42
  continue
45
43
 
46
44
  args = CatArgs(*cat.args, **cat.kwargs) # type: ignore[arg-type]
tico/passes/remove_nop.py CHANGED
@@ -23,6 +23,7 @@ from tico.passes import ops
23
23
  from tico.utils import logging
24
24
  from tico.utils.passes import PassBase, PassResult
25
25
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
26
+ from tico.utils.utils import is_target_node
26
27
 
27
28
 
28
29
  @trace_graph_diff_on_pass
@@ -51,11 +52,9 @@ class RemoveNop(PassBase):
51
52
  graph = graph_module.graph
52
53
  modified = False
53
54
  for node in graph.nodes:
54
- if not node.op == "call_function":
55
+ if not is_target_node(node, RemoveNop.target_ops):
55
56
  continue
56
57
 
57
- if not node.target in RemoveNop.target_ops:
58
- continue
59
58
  # TODO Consider memory format
60
59
  if node.target in ops.aten.clone and "memory_format" in node.kwargs:
61
60
  if node.kwargs["memory_format"] not in [
@@ -17,6 +17,7 @@ from torch.export import ExportedProgram
17
17
 
18
18
  from tico.utils.passes import PassBase, PassResult
19
19
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
20
+ from tico.utils.utils import is_target_node
20
21
 
21
22
 
22
23
  assert_node_targets = [
@@ -39,7 +40,7 @@ class RemoveRedundantAssertionNodes(PassBase):
39
40
  graph = graph_module.graph
40
41
  modified = False
41
42
  for node in graph.nodes:
42
- if node.op == "call_function" and node.target in assert_node_targets:
43
+ if is_target_node(node, assert_node_targets):
43
44
  graph.erase_node(node)
44
45
  modified = True
45
46
 
@@ -24,6 +24,8 @@ from tico.serialize.circle_mapping import extract_shape
24
24
  from tico.utils import logging
25
25
  from tico.utils.passes import PassBase, PassResult
26
26
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
+ from tico.utils.utils import is_target_node
28
+ from tico.utils.validate_args_kwargs import ExpandArgs
27
29
 
28
30
 
29
31
  @trace_graph_diff_on_pass
@@ -42,17 +44,11 @@ class RemoveRedundantExpand(PassBase):
42
44
  graph = graph_module.graph
43
45
  modified = False
44
46
  for node in graph.nodes:
45
- if not node.op == "call_function":
47
+ if not is_target_node(node, ops.aten.expand):
46
48
  continue
47
49
 
48
- if not node.target in ops.aten.expand:
49
- continue
50
-
51
- assert len(node.args) == 2
52
-
53
- input, size = list(node.args)
54
- assert isinstance(input, torch.fx.Node), type(input)
55
- assert isinstance(size, list), type(size)
50
+ args = ExpandArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
51
+ input, size = args.input, args.size
56
52
 
57
53
  input_shape = extract_shape(input)
58
54
  if list(input_shape) != size:
@@ -14,16 +14,18 @@
14
14
 
15
15
  from typing import TYPE_CHECKING
16
16
 
17
+
17
18
  if TYPE_CHECKING:
18
19
  import torch.fx
19
20
  import torch
20
21
  from torch.export import ExportedProgram
21
22
 
22
23
  from tico.passes import ops
23
- from tico.serialize.circle_mapping import extract_shape, extract_stride
24
+ from tico.serialize.circle_mapping import extract_shape
24
25
  from tico.utils import logging
25
26
  from tico.utils.passes import PassBase, PassResult
26
27
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
+ from tico.utils.utils import is_target_node
27
29
  from tico.utils.validate_args_kwargs import PermuteArgs
28
30
 
29
31
 
@@ -75,16 +77,15 @@ class RemoveRedundantPermutePattern1(PassBase):
75
77
  graph = graph_module.graph
76
78
  modified = False
77
79
  for permute2 in graph.nodes:
78
- if not permute2.op == "call_function":
79
- continue
80
- if not permute2.target in ops.aten.permute:
80
+ if not is_target_node(permute2, ops.aten.permute):
81
81
  continue
82
+
82
83
  if len(permute2.users) != 1:
83
84
  continue
84
85
  permute2_args = PermuteArgs(*permute2.args, **permute2.kwargs) # type: ignore[arg-type]
85
86
  permute1, permute2_dims = permute2_args.input, permute2_args.dims
86
87
 
87
- if not permute1.target in ops.aten.permute:
88
+ if not is_target_node(permute1, ops.aten.permute):
88
89
  continue
89
90
  if len(permute1.users) != 1:
90
91
  continue
@@ -24,7 +24,7 @@ from tico.serialize.circle_mapping import extract_shape
24
24
  from tico.utils import logging
25
25
  from tico.utils.passes import PassBase, PassResult
26
26
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
- from tico.utils.utils import broadcastable, set_new_meta_val
27
+ from tico.utils.utils import broadcastable, is_target_node, set_new_meta_val
28
28
  from tico.utils.validate_args_kwargs import (
29
29
  AddTensorArgs,
30
30
  PermuteArgs,
@@ -67,12 +67,10 @@ class RemoveRedundantReshapePattern1(PassBase):
67
67
  graph = graph_module.graph
68
68
  modified = False
69
69
  for reshape1 in graph.nodes:
70
- if not reshape1.op == "call_function":
71
- continue
72
-
73
70
  ### first reshape
74
- if not reshape1.target in ops.aten.reshape:
71
+ if not is_target_node(reshape1, ops.aten.reshape):
75
72
  continue
73
+
76
74
  # Assumes that other nodes do not use ops in the pattern, for simplicity.
77
75
  if len(reshape1.users) != 1:
78
76
  continue
@@ -86,7 +84,7 @@ class RemoveRedundantReshapePattern1(PassBase):
86
84
 
87
85
  ### permute
88
86
  permute = next(iter(reshape1.users))
89
- if not permute.target in ops.aten.permute:
87
+ if not is_target_node(permute, ops.aten.permute):
90
88
  continue
91
89
  if len(permute.users) != 1:
92
90
  continue
@@ -98,14 +96,14 @@ class RemoveRedundantReshapePattern1(PassBase):
98
96
 
99
97
  ### mul
100
98
  mul = next(iter(permute.users))
101
- if not mul.target in RemoveRedundantReshapePattern1.mul_ops:
99
+ if not is_target_node(mul, RemoveRedundantReshapePattern1.mul_ops):
102
100
  continue
103
101
  if len(mul.users) != 1:
104
102
  continue
105
103
 
106
104
  ### second reshape
107
105
  reshape2 = next(iter(mul.users))
108
- if not reshape2.target in ops.aten.reshape:
106
+ if not is_target_node(reshape2, ops.aten.reshape):
109
107
  continue
110
108
  if len(reshape2.users) != 1:
111
109
  continue
@@ -153,11 +151,8 @@ class RemoveRedundantReshapePattern2(PassBase):
153
151
  graph = graph_module.graph
154
152
  modified = False
155
153
  for reshape1 in graph.nodes:
156
- if not reshape1.op == "call_function":
157
- continue
158
-
159
154
  ### first reshape
160
- if not reshape1.target in ops.aten.reshape:
155
+ if not is_target_node(reshape1, ops.aten.reshape):
161
156
  continue
162
157
  if len(reshape1.users) != 1:
163
158
  continue
@@ -171,7 +166,7 @@ class RemoveRedundantReshapePattern2(PassBase):
171
166
 
172
167
  ### permute
173
168
  permute = next(iter(reshape1.users))
174
- if not permute.target in ops.aten.permute:
169
+ if not is_target_node(permute, ops.aten.permute):
175
170
  continue
176
171
  if len(permute.users) != 1:
177
172
  continue
@@ -183,7 +178,7 @@ class RemoveRedundantReshapePattern2(PassBase):
183
178
 
184
179
  ### second reshape
185
180
  reshape2 = next(iter(permute.users))
186
- if not reshape2.target in ops.aten.reshape:
181
+ if not is_target_node(reshape2, ops.aten.reshape):
187
182
  continue
188
183
  if len(reshape2.users) != 1:
189
184
  continue
@@ -239,20 +234,14 @@ class RemoveRedundantReshapePattern3(PassBase):
239
234
  graph = graph_module.graph
240
235
  modified = False
241
236
  for reshape_1 in graph.nodes:
242
- assert isinstance(reshape_1, torch.fx.Node), type(reshape_1)
243
237
  # reshape_1
244
- if not reshape_1.op == "call_function":
245
- continue
246
- if not reshape_1.target in ops.aten.reshape:
238
+ if not is_target_node(reshape_1, ops.aten.reshape):
247
239
  continue
248
240
  reshape_1_args = ReshapeArgs(*reshape_1.args, **reshape_1.kwargs) # type: ignore[arg-type]
249
- softmax, reshape_1_size = reshape_1_args.input, reshape_1_args.size
241
+ softmax = reshape_1_args.input
250
242
 
251
243
  # softmax
252
- assert isinstance(softmax, torch.fx.Node), type(softmax)
253
- if not softmax.op == "call_function":
254
- continue
255
- if not softmax.target in ops.aten.softmax:
244
+ if not is_target_node(softmax, ops.aten.softmax):
256
245
  continue
257
246
  if softmax.target == torch.ops.aten._softmax.default:
258
247
  softmax_args = SoftmaxArgs(*softmax.args, **softmax.kwargs) # type: ignore[arg-type, assignment]
@@ -354,10 +343,9 @@ class RemoveRedundantReshapePattern4(PassBase):
354
343
  modified = False
355
344
  for reshape1 in graph.nodes:
356
345
  # reshape_1
357
- if not reshape1.op == "call_function":
358
- continue
359
- if not reshape1.target in ops.aten.reshape:
346
+ if not is_target_node(reshape1, ops.aten.reshape):
360
347
  continue
348
+
361
349
  reshape1_args = ReshapeArgs(*reshape1.args, **reshape1.kwargs) # type: ignore[arg-type]
362
350
  reshape1_input, size = reshape1_args.input, reshape1_args.size
363
351
  assert isinstance(reshape1_input, torch.fx.Node), type(reshape1_input)
@@ -370,12 +358,10 @@ class RemoveRedundantReshapePattern4(PassBase):
370
358
 
371
359
  # reshape_2
372
360
  reshape2 = next(iter(reshape1.users))
373
- if not reshape2.op == "call_function":
361
+ if not is_target_node(reshape2, ops.aten.reshape):
374
362
  continue
375
- if not reshape2.target in ops.aten.reshape:
376
- continue
377
- reshape2_args = ReshapeArgs(*reshape2.args, **reshape2.kwargs) # type: ignore[arg-type]
378
363
 
364
+ reshape2_args = ReshapeArgs(*reshape2.args, **reshape2.kwargs) # type: ignore[arg-type]
379
365
  reshape2_input, reshape2_size = reshape2_args.input, reshape2_args.size
380
366
  assert isinstance(reshape2_input, torch.fx.Node), type(reshape2_input)
381
367
  assert isinstance(reshape2_size, list), type(reshape2_size)
@@ -420,10 +406,7 @@ class RemoveRedundantReshapePattern5(PassBase):
420
406
  modified = False
421
407
 
422
408
  for node in graph.nodes:
423
- if not node.op == "call_function":
424
- continue
425
-
426
- if not node.target in ops.aten.reshape:
409
+ if not is_target_node(node, ops.aten.reshape):
427
410
  continue
428
411
 
429
412
  args = ReshapeArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
@@ -19,6 +19,7 @@ from tico.serialize.circle_mapping import extract_shape
19
19
  from tico.utils import logging
20
20
  from tico.utils.passes import PassBase, PassResult
21
21
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
22
+ from tico.utils.utils import is_target_node
22
23
  from tico.utils.validate_args_kwargs import SliceArgs
23
24
 
24
25
 
@@ -38,10 +39,7 @@ class RemoveRedundantSlice(PassBase):
38
39
  graph = graph_module.graph
39
40
  modified = False
40
41
  for node in graph.nodes:
41
- if not node.op == "call_function":
42
- continue
43
-
44
- if not node.target in ops.aten.slice:
42
+ if not is_target_node(node, ops.aten.slice):
45
43
  continue
46
44
 
47
45
  args = SliceArgs(*node.args, **node.kwargs)
@@ -22,6 +22,7 @@ from tico.serialize.circle_mapping import extract_torch_dtype
22
22
  from tico.utils import logging
23
23
  from tico.utils.passes import PassBase, PassResult
24
24
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
25
+ from tico.utils.utils import is_target_node
25
26
  from tico.utils.validate_args_kwargs import ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs
26
27
 
27
28
 
@@ -41,10 +42,7 @@ class RemoveRedundantToCopy(PassBase):
41
42
  graph = graph_module.graph
42
43
  modified = False
43
44
  for node in graph.nodes:
44
- if not node.op == "call_function":
45
- continue
46
-
47
- if not node.target in ops.aten.to_copy:
45
+ if not is_target_node(node, ops.aten.to_copy):
48
46
  continue
49
47
 
50
48
  args: Union[ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs]
@@ -34,6 +34,7 @@ from tico.utils import logging
34
34
  from tico.utils.graph import add_placeholder, is_single_value_tensor
35
35
  from tico.utils.passes import PassBase, PassResult
36
36
  from tico.utils.trace_decorators import trace_const_diff_on_pass
37
+ from tico.utils.utils import is_target_node
37
38
  from tico.utils.validate_args_kwargs import IndexSelectArgs
38
39
 
39
40
 
@@ -78,10 +79,7 @@ class SegmentIndexSelectConst(PassBase):
78
79
  graph = graph_module.graph
79
80
  modified = False
80
81
  for node in graph.nodes:
81
- if not node.op == "call_function":
82
- continue
83
-
84
- if not node.target in ops.aten.index_select:
82
+ if not is_target_node(node, ops.aten.index_select):
85
83
  continue
86
84
 
87
85
  args = IndexSelectArgs(*node.args, **node.kwargs)