tico 0.1.0.dev250610__py3-none-any.whl → 0.1.0.dev250615__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/passes/cast_aten_where_arg_type.py +65 -62
- tico/passes/cast_mixed_type_args.py +2 -5
- tico/passes/convert_conv1d_to_conv2d.py +3 -4
- tico/passes/convert_repeat_to_expand_copy.py +5 -9
- tico/passes/decompose_addmm.py +41 -48
- tico/passes/decompose_batch_norm.py +97 -99
- tico/passes/decompose_fake_quantize.py +4 -6
- tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -4
- tico/passes/decompose_group_norm.py +8 -7
- tico/passes/decompose_grouped_conv2d.py +2 -3
- tico/passes/decompose_slice_scatter.py +2 -4
- tico/passes/extract_dtype_kwargs.py +2 -1
- tico/passes/fuse_leading_unsqueeze_reshape.py +6 -9
- tico/passes/fuse_redundant_reshape_to_mean.py +3 -5
- tico/passes/legalize_causal_mask_value.py +2 -7
- tico/passes/legalize_predefined_layout_operators.py +2 -3
- tico/passes/lower_pow2_to_mul.py +5 -7
- tico/passes/lower_to_resize_nearest_neighbor.py +6 -10
- tico/passes/lower_to_slice.py +3 -9
- tico/passes/merge_consecutive_cat.py +2 -4
- tico/passes/remove_nop.py +2 -3
- tico/passes/remove_redundant_assert_nodes.py +2 -1
- tico/passes/remove_redundant_expand.py +5 -9
- tico/passes/remove_redundant_permute.py +6 -5
- tico/passes/remove_redundant_reshape.py +26 -43
- tico/passes/remove_redundant_slice.py +2 -4
- tico/passes/remove_redundant_to_copy.py +2 -4
- tico/passes/segment_index_select.py +2 -4
- tico/serialize/operators/op_reshape.py +1 -1
- tico/serialize/operators/op_where.py +2 -2
- tico/utils/utils.py +4 -6
- tico/utils/validate_args_kwargs.py +1 -1
- {tico-0.1.0.dev250610.dist-info → tico-0.1.0.dev250615.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250610.dist-info → tico-0.1.0.dev250615.dist-info}/RECORD +39 -39
- {tico-0.1.0.dev250610.dist-info → tico-0.1.0.dev250615.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250610.dist-info → tico-0.1.0.dev250615.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250610.dist-info → tico-0.1.0.dev250615.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250610.dist-info → tico-0.1.0.dev250615.dist-info}/top_level.txt +0 -0
@@ -25,7 +25,7 @@ from tico.serialize.circle_mapping import extract_shape
|
|
25
25
|
from tico.utils import logging
|
26
26
|
from tico.utils.passes import PassBase, PassResult
|
27
27
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
|
-
from tico.utils.utils import enforce_type
|
28
|
+
from tico.utils.utils import enforce_type, is_target_node
|
29
29
|
|
30
30
|
|
31
31
|
@trace_graph_diff_on_pass
|
@@ -82,9 +82,7 @@ class DecomposeSliceScatter(PassBase):
|
|
82
82
|
modified = False
|
83
83
|
|
84
84
|
for node in graph.nodes:
|
85
|
-
if node.
|
86
|
-
continue
|
87
|
-
if node.target != torch.ops.aten.slice_scatter.default:
|
85
|
+
if not is_target_node(node, torch.ops.aten.slice_scatter.default):
|
88
86
|
continue
|
89
87
|
|
90
88
|
@enforce_type
|
@@ -23,6 +23,7 @@ from torch.utils import _pytree as pytree
|
|
23
23
|
from tico.utils import logging
|
24
24
|
from tico.utils.passes import PassBase, PassResult
|
25
25
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
26
|
+
from tico.utils.utils import is_target_node
|
26
27
|
|
27
28
|
|
28
29
|
def _extract_to_output(node: torch.fx.Node, graph: torch.fx.Graph) -> bool:
|
@@ -107,7 +108,7 @@ class ExtractDtypeKwargsPass(PassBase):
|
|
107
108
|
graph: torch.fx.Graph = graph_module.graph
|
108
109
|
modified = False
|
109
110
|
for node in graph.nodes:
|
110
|
-
if not node
|
111
|
+
if not is_target_node(node, list(self.target_ops.keys())):
|
111
112
|
continue
|
112
113
|
if "dtype" not in node.kwargs:
|
113
114
|
continue
|
@@ -22,7 +22,7 @@ from tico.serialize.circle_mapping import extract_shape
|
|
22
22
|
from tico.utils import logging
|
23
23
|
from tico.utils.passes import PassBase, PassResult
|
24
24
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
25
|
-
from tico.utils.utils import
|
25
|
+
from tico.utils.utils import is_target_node
|
26
26
|
from tico.utils.validate_args_kwargs import PermuteArgs, ReshapeArgs
|
27
27
|
|
28
28
|
|
@@ -55,25 +55,22 @@ class FuseLeadingUnsqueezeReshape(PassBase):
|
|
55
55
|
graph = gm.graph
|
56
56
|
modified = False
|
57
57
|
for reshape_back in graph.nodes:
|
58
|
-
if (
|
59
|
-
reshape_back.op != "call_function"
|
60
|
-
or reshape_back.target not in ops.aten.reshape
|
61
|
-
):
|
58
|
+
if not is_target_node(reshape_back, ops.aten.reshape):
|
62
59
|
continue
|
63
60
|
reshape_back_args = ReshapeArgs(*reshape_back.args, **reshape_back.kwargs) # type: ignore[arg-type]
|
64
|
-
permute
|
61
|
+
permute = reshape_back_args.input
|
65
62
|
|
66
|
-
if not
|
63
|
+
if not is_target_node(permute, ops.aten.permute):
|
67
64
|
continue
|
68
65
|
permute_args = PermuteArgs(*permute.args, **permute.kwargs) # type: ignore[arg-type]
|
69
66
|
reshape_front, permute_dims = permute_args.input, permute_args.dims
|
70
67
|
|
71
|
-
if not
|
68
|
+
if not is_target_node(reshape_front, ops.aten.reshape):
|
72
69
|
continue
|
73
70
|
reshape_front_args = ReshapeArgs(*reshape_front.args, **reshape_front.kwargs) # type: ignore[arg-type]
|
74
71
|
reshape_front_input, reshape_front_size = (
|
75
72
|
reshape_front_args.input,
|
76
|
-
reshape_front_args.
|
73
|
+
reshape_front_args.shape,
|
77
74
|
)
|
78
75
|
|
79
76
|
# ---- condition: only leading unsqueeze ------------------------
|
@@ -25,6 +25,7 @@ from tico.serialize.circle_mapping import extract_shape
|
|
25
25
|
from tico.utils import logging
|
26
26
|
from tico.utils.passes import PassBase, PassResult
|
27
27
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
|
+
from tico.utils.utils import is_target_node
|
28
29
|
|
29
30
|
|
30
31
|
@trace_graph_diff_on_pass
|
@@ -45,10 +46,7 @@ class FuseRedundantReshapeToMean(PassBase):
|
|
45
46
|
graph = graph_module.graph
|
46
47
|
modified = False
|
47
48
|
for node in graph.nodes:
|
48
|
-
if not node.
|
49
|
-
continue
|
50
|
-
|
51
|
-
if node.target != torch.ops.aten.mean.dim:
|
49
|
+
if not is_target_node(node, torch.ops.aten.mean.dim):
|
52
50
|
continue
|
53
51
|
|
54
52
|
# If mean is being used in other nodes, do not fuse it.
|
@@ -56,7 +54,7 @@ class FuseRedundantReshapeToMean(PassBase):
|
|
56
54
|
continue
|
57
55
|
|
58
56
|
user_node = next(iter(node.users))
|
59
|
-
if
|
57
|
+
if not is_target_node(user_node, ops.aten.reshape):
|
60
58
|
continue
|
61
59
|
|
62
60
|
mean_args, mean_kwargs = pytree.tree_map_only(
|
@@ -20,10 +20,10 @@ import torch
|
|
20
20
|
from torch.export import ExportedProgram
|
21
21
|
|
22
22
|
from tico.passes import ops
|
23
|
-
|
24
23
|
from tico.utils import logging
|
25
24
|
from tico.utils.passes import PassBase, PassResult
|
26
25
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
26
|
+
from tico.utils.utils import is_target_node
|
27
27
|
from tico.utils.validate_args_kwargs import AddTensorArgs
|
28
28
|
|
29
29
|
|
@@ -53,14 +53,9 @@ class LegalizeCausalMaskValue(PassBase):
|
|
53
53
|
graph = graph_module.graph
|
54
54
|
modified = False
|
55
55
|
for node in graph.nodes:
|
56
|
-
if not node.
|
57
|
-
continue
|
58
|
-
|
59
|
-
if not node.target in ops.aten.add:
|
56
|
+
if not is_target_node(node, ops.aten.add):
|
60
57
|
continue
|
61
58
|
|
62
|
-
assert len(node.args) == 2
|
63
|
-
|
64
59
|
args = AddTensorArgs(*node.args, **node.kwargs)
|
65
60
|
input = args.input
|
66
61
|
other = args.other
|
@@ -25,6 +25,7 @@ from tico.utils import logging
|
|
25
25
|
from tico.utils.errors import NotYetSupportedError
|
26
26
|
from tico.utils.passes import PassBase, PassResult
|
27
27
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
|
+
from tico.utils.utils import is_target_node
|
28
29
|
from tico.utils.validate_args_kwargs import (
|
29
30
|
AvgPool2dArgs,
|
30
31
|
Conv2DArgs,
|
@@ -363,11 +364,9 @@ class LegalizePreDefinedLayoutOperators(PassBase):
|
|
363
364
|
graph = graph_module.graph
|
364
365
|
modified = False
|
365
366
|
for node in graph.nodes:
|
366
|
-
if not node.
|
367
|
+
if not is_target_node(node, list(target_to_legalize_func.keys())):
|
367
368
|
continue
|
368
369
|
|
369
|
-
if node.target not in target_to_legalize_func:
|
370
|
-
continue
|
371
370
|
modified |= target_to_legalize_func[node.target](exported_program, node)
|
372
371
|
|
373
372
|
graph.eliminate_dead_code()
|
tico/passes/lower_pow2_to_mul.py
CHANGED
@@ -22,6 +22,8 @@ from torch.export import ExportedProgram
|
|
22
22
|
from tico.utils import logging
|
23
23
|
from tico.utils.passes import PassBase, PassResult
|
24
24
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
25
|
+
from tico.utils.utils import is_target_node
|
26
|
+
from tico.utils.validate_args_kwargs import PowTensorScalarArgs
|
25
27
|
|
26
28
|
|
27
29
|
@trace_graph_diff_on_pass
|
@@ -42,15 +44,11 @@ class LowerPow2ToMul(PassBase):
|
|
42
44
|
graph = graph_module.graph
|
43
45
|
modified = False
|
44
46
|
for node in graph.nodes:
|
45
|
-
if not node.
|
47
|
+
if not is_target_node(node, torch.ops.aten.pow.Tensor_Scalar):
|
46
48
|
continue
|
47
49
|
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
assert len(node.args) == 2, len(node.args)
|
52
|
-
in_, exp = node.args
|
53
|
-
assert isinstance(in_, torch.fx.Node), type(in_)
|
50
|
+
args = PowTensorScalarArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
51
|
+
in_, exp = args.input, args.exponent
|
54
52
|
|
55
53
|
if exp != 2:
|
56
54
|
continue
|
@@ -12,12 +12,10 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from typing import TYPE_CHECKING
|
15
|
+
from typing import Optional, TYPE_CHECKING
|
16
16
|
|
17
17
|
if TYPE_CHECKING:
|
18
18
|
import torch.fx
|
19
|
-
from typing import Optional
|
20
|
-
|
21
19
|
import torch
|
22
20
|
from torch.export import ExportedProgram
|
23
21
|
|
@@ -26,6 +24,7 @@ from tico.utils import logging
|
|
26
24
|
from tico.utils.errors import NotYetSupportedError
|
27
25
|
from tico.utils.passes import PassBase, PassResult
|
28
26
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
|
+
from tico.utils.utils import is_target_node
|
29
28
|
from tico.utils.validate_args_kwargs import IndexArgs, UpsampleNearest2DVecArgs
|
30
29
|
|
31
30
|
|
@@ -194,13 +193,10 @@ class LowerToResizeNearestNeighbor(PassBase):
|
|
194
193
|
graph_module = exported_program.graph_module
|
195
194
|
graph = graph_module.graph
|
196
195
|
for node in graph.nodes:
|
197
|
-
if not
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
torch.ops.aten.index.Tensor,
|
202
|
-
torch.ops.aten.upsample_nearest2d.vec,
|
203
|
-
]:
|
196
|
+
if not is_target_node(
|
197
|
+
node,
|
198
|
+
[torch.ops.aten.index.Tensor, torch.ops.aten.upsample_nearest2d.vec],
|
199
|
+
):
|
204
200
|
continue
|
205
201
|
|
206
202
|
resize_nearest_neighbor = None
|
tico/passes/lower_to_slice.py
CHANGED
@@ -28,13 +28,12 @@ from torch._export.utils import (
|
|
28
28
|
from torch.export import ExportedProgram
|
29
29
|
|
30
30
|
from tico.passes import ops
|
31
|
-
|
32
31
|
from tico.serialize.circle_graph import extract_shape
|
33
32
|
from tico.utils import logging
|
34
|
-
|
35
33
|
from tico.utils.graph import is_single_value_tensor
|
36
34
|
from tico.utils.passes import PassBase, PassResult
|
37
35
|
from tico.utils.trace_decorators import trace_const_diff_on_pass
|
36
|
+
from tico.utils.utils import is_target_node
|
38
37
|
from tico.utils.validate_args_kwargs import IndexSelectArgs, SelectCopyIntArgs
|
39
38
|
|
40
39
|
|
@@ -85,10 +84,7 @@ class LowerSelectCopyToSlice(PassBase):
|
|
85
84
|
graph = graph_module.graph
|
86
85
|
modified = False
|
87
86
|
for node in graph.nodes:
|
88
|
-
if not node.
|
89
|
-
continue
|
90
|
-
|
91
|
-
if not node.target in ops.aten.select:
|
87
|
+
if not is_target_node(node, ops.aten.select):
|
92
88
|
continue
|
93
89
|
|
94
90
|
args = SelectCopyIntArgs(*node.args, **node.kwargs)
|
@@ -163,11 +159,9 @@ class LowerIndexSelectToSlice(PassBase):
|
|
163
159
|
graph = graph_module.graph
|
164
160
|
modified = False
|
165
161
|
for node in graph.nodes:
|
166
|
-
if not node.
|
162
|
+
if not is_target_node(node, ops.aten.index_select):
|
167
163
|
continue
|
168
164
|
|
169
|
-
if not node.target in ops.aten.index_select:
|
170
|
-
continue
|
171
165
|
args = IndexSelectArgs(*node.args, **node.kwargs)
|
172
166
|
input = args.input
|
173
167
|
dim = args.dim
|
@@ -18,6 +18,7 @@ from tico.passes import ops
|
|
18
18
|
from tico.utils import logging
|
19
19
|
from tico.utils.passes import PassBase, PassResult
|
20
20
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
21
|
+
from tico.utils.utils import is_target_node
|
21
22
|
from tico.utils.validate_args_kwargs import CatArgs
|
22
23
|
|
23
24
|
|
@@ -37,10 +38,7 @@ class MergeConsecutiveCat(PassBase):
|
|
37
38
|
graph = graph_module.graph
|
38
39
|
modified = False
|
39
40
|
for cat in graph.nodes:
|
40
|
-
if not cat.
|
41
|
-
continue
|
42
|
-
|
43
|
-
if not cat.target in ops.aten.cat:
|
41
|
+
if not is_target_node(cat, ops.aten.cat):
|
44
42
|
continue
|
45
43
|
|
46
44
|
args = CatArgs(*cat.args, **cat.kwargs) # type: ignore[arg-type]
|
tico/passes/remove_nop.py
CHANGED
@@ -23,6 +23,7 @@ from tico.passes import ops
|
|
23
23
|
from tico.utils import logging
|
24
24
|
from tico.utils.passes import PassBase, PassResult
|
25
25
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
26
|
+
from tico.utils.utils import is_target_node
|
26
27
|
|
27
28
|
|
28
29
|
@trace_graph_diff_on_pass
|
@@ -51,11 +52,9 @@ class RemoveNop(PassBase):
|
|
51
52
|
graph = graph_module.graph
|
52
53
|
modified = False
|
53
54
|
for node in graph.nodes:
|
54
|
-
if not node.
|
55
|
+
if not is_target_node(node, RemoveNop.target_ops):
|
55
56
|
continue
|
56
57
|
|
57
|
-
if not node.target in RemoveNop.target_ops:
|
58
|
-
continue
|
59
58
|
# TODO Consider memory format
|
60
59
|
if node.target in ops.aten.clone and "memory_format" in node.kwargs:
|
61
60
|
if node.kwargs["memory_format"] not in [
|
@@ -17,6 +17,7 @@ from torch.export import ExportedProgram
|
|
17
17
|
|
18
18
|
from tico.utils.passes import PassBase, PassResult
|
19
19
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
20
|
+
from tico.utils.utils import is_target_node
|
20
21
|
|
21
22
|
|
22
23
|
assert_node_targets = [
|
@@ -39,7 +40,7 @@ class RemoveRedundantAssertionNodes(PassBase):
|
|
39
40
|
graph = graph_module.graph
|
40
41
|
modified = False
|
41
42
|
for node in graph.nodes:
|
42
|
-
if node
|
43
|
+
if is_target_node(node, assert_node_targets):
|
43
44
|
graph.erase_node(node)
|
44
45
|
modified = True
|
45
46
|
|
@@ -24,6 +24,8 @@ from tico.serialize.circle_mapping import extract_shape
|
|
24
24
|
from tico.utils import logging
|
25
25
|
from tico.utils.passes import PassBase, PassResult
|
26
26
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
|
+
from tico.utils.utils import is_target_node
|
28
|
+
from tico.utils.validate_args_kwargs import ExpandArgs
|
27
29
|
|
28
30
|
|
29
31
|
@trace_graph_diff_on_pass
|
@@ -42,17 +44,11 @@ class RemoveRedundantExpand(PassBase):
|
|
42
44
|
graph = graph_module.graph
|
43
45
|
modified = False
|
44
46
|
for node in graph.nodes:
|
45
|
-
if not node.
|
47
|
+
if not is_target_node(node, ops.aten.expand):
|
46
48
|
continue
|
47
49
|
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
assert len(node.args) == 2
|
52
|
-
|
53
|
-
input, size = list(node.args)
|
54
|
-
assert isinstance(input, torch.fx.Node), type(input)
|
55
|
-
assert isinstance(size, list), type(size)
|
50
|
+
args = ExpandArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
51
|
+
input, size = args.input, args.size
|
56
52
|
|
57
53
|
input_shape = extract_shape(input)
|
58
54
|
if list(input_shape) != size:
|
@@ -14,16 +14,18 @@
|
|
14
14
|
|
15
15
|
from typing import TYPE_CHECKING
|
16
16
|
|
17
|
+
|
17
18
|
if TYPE_CHECKING:
|
18
19
|
import torch.fx
|
19
20
|
import torch
|
20
21
|
from torch.export import ExportedProgram
|
21
22
|
|
22
23
|
from tico.passes import ops
|
23
|
-
from tico.serialize.circle_mapping import extract_shape
|
24
|
+
from tico.serialize.circle_mapping import extract_shape
|
24
25
|
from tico.utils import logging
|
25
26
|
from tico.utils.passes import PassBase, PassResult
|
26
27
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
|
+
from tico.utils.utils import is_target_node
|
27
29
|
from tico.utils.validate_args_kwargs import PermuteArgs
|
28
30
|
|
29
31
|
|
@@ -75,16 +77,15 @@ class RemoveRedundantPermutePattern1(PassBase):
|
|
75
77
|
graph = graph_module.graph
|
76
78
|
modified = False
|
77
79
|
for permute2 in graph.nodes:
|
78
|
-
if not permute2.
|
79
|
-
continue
|
80
|
-
if not permute2.target in ops.aten.permute:
|
80
|
+
if not is_target_node(permute2, ops.aten.permute):
|
81
81
|
continue
|
82
|
+
|
82
83
|
if len(permute2.users) != 1:
|
83
84
|
continue
|
84
85
|
permute2_args = PermuteArgs(*permute2.args, **permute2.kwargs) # type: ignore[arg-type]
|
85
86
|
permute1, permute2_dims = permute2_args.input, permute2_args.dims
|
86
87
|
|
87
|
-
if not permute1
|
88
|
+
if not is_target_node(permute1, ops.aten.permute):
|
88
89
|
continue
|
89
90
|
if len(permute1.users) != 1:
|
90
91
|
continue
|
@@ -24,7 +24,7 @@ from tico.serialize.circle_mapping import extract_shape
|
|
24
24
|
from tico.utils import logging
|
25
25
|
from tico.utils.passes import PassBase, PassResult
|
26
26
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
|
-
from tico.utils.utils import broadcastable, set_new_meta_val
|
27
|
+
from tico.utils.utils import broadcastable, is_target_node, set_new_meta_val
|
28
28
|
from tico.utils.validate_args_kwargs import (
|
29
29
|
AddTensorArgs,
|
30
30
|
PermuteArgs,
|
@@ -67,17 +67,15 @@ class RemoveRedundantReshapePattern1(PassBase):
|
|
67
67
|
graph = graph_module.graph
|
68
68
|
modified = False
|
69
69
|
for reshape1 in graph.nodes:
|
70
|
-
if not reshape1.op == "call_function":
|
71
|
-
continue
|
72
|
-
|
73
70
|
### first reshape
|
74
|
-
if not reshape1
|
71
|
+
if not is_target_node(reshape1, ops.aten.reshape):
|
75
72
|
continue
|
73
|
+
|
76
74
|
# Assumes that other node do not use ops in the pattern for simplisity.
|
77
75
|
if len(reshape1.users) != 1:
|
78
76
|
continue
|
79
77
|
reshape1_args = ReshapeArgs(*reshape1.args, **reshape1.kwargs) # type: ignore[arg-type]
|
80
|
-
reshape1_input
|
78
|
+
reshape1_input = reshape1_args.input
|
81
79
|
# `(AxBxC) - aten.reshape` - (1xAxBxC)
|
82
80
|
if [1] + list(extract_shape(reshape1_input)) != list(
|
83
81
|
extract_shape(reshape1)
|
@@ -86,7 +84,7 @@ class RemoveRedundantReshapePattern1(PassBase):
|
|
86
84
|
|
87
85
|
### permute
|
88
86
|
permute = next(iter(reshape1.users))
|
89
|
-
if not permute
|
87
|
+
if not is_target_node(permute, ops.aten.permute):
|
90
88
|
continue
|
91
89
|
if len(permute.users) != 1:
|
92
90
|
continue
|
@@ -98,19 +96,19 @@ class RemoveRedundantReshapePattern1(PassBase):
|
|
98
96
|
|
99
97
|
### mul
|
100
98
|
mul = next(iter(permute.users))
|
101
|
-
if not mul
|
99
|
+
if not is_target_node(mul, RemoveRedundantReshapePattern1.mul_ops):
|
102
100
|
continue
|
103
101
|
if len(mul.users) != 1:
|
104
102
|
continue
|
105
103
|
|
106
104
|
### second reshape
|
107
105
|
reshape2 = next(iter(mul.users))
|
108
|
-
if not reshape2
|
106
|
+
if not is_target_node(reshape2, ops.aten.reshape):
|
109
107
|
continue
|
110
108
|
if len(reshape2.users) != 1:
|
111
109
|
continue
|
112
110
|
reshape2_args = ReshapeArgs(*reshape2.args, **reshape2.kwargs) # type: ignore[arg-type]
|
113
|
-
reshape2_input
|
111
|
+
reshape2_input = reshape2_args.input
|
114
112
|
# (1xAxCxB) - `aten.reshape - (AxCxB)
|
115
113
|
if list(extract_shape(reshape2_input)) != [1] + list(
|
116
114
|
extract_shape(reshape2)
|
@@ -153,16 +151,13 @@ class RemoveRedundantReshapePattern2(PassBase):
|
|
153
151
|
graph = graph_module.graph
|
154
152
|
modified = False
|
155
153
|
for reshape1 in graph.nodes:
|
156
|
-
if not reshape1.op == "call_function":
|
157
|
-
continue
|
158
|
-
|
159
154
|
### first reshape
|
160
|
-
if not reshape1
|
155
|
+
if not is_target_node(reshape1, ops.aten.reshape):
|
161
156
|
continue
|
162
157
|
if len(reshape1.users) != 1:
|
163
158
|
continue
|
164
159
|
reshape1_args = ReshapeArgs(*reshape1.args, **reshape1.kwargs) # type: ignore[arg-type]
|
165
|
-
reshape1_input
|
160
|
+
reshape1_input = reshape1_args.input
|
166
161
|
# `(AxBxC) - aten.reshape` - (1xAxBxC)
|
167
162
|
if [1] + list(extract_shape(reshape1_input)) != list(
|
168
163
|
extract_shape(reshape1)
|
@@ -171,7 +166,7 @@ class RemoveRedundantReshapePattern2(PassBase):
|
|
171
166
|
|
172
167
|
### permute
|
173
168
|
permute = next(iter(reshape1.users))
|
174
|
-
if not permute
|
169
|
+
if not is_target_node(permute, ops.aten.permute):
|
175
170
|
continue
|
176
171
|
if len(permute.users) != 1:
|
177
172
|
continue
|
@@ -183,12 +178,12 @@ class RemoveRedundantReshapePattern2(PassBase):
|
|
183
178
|
|
184
179
|
### second reshape
|
185
180
|
reshape2 = next(iter(permute.users))
|
186
|
-
if not reshape2
|
181
|
+
if not is_target_node(reshape2, ops.aten.reshape):
|
187
182
|
continue
|
188
183
|
if len(reshape2.users) != 1:
|
189
184
|
continue
|
190
185
|
reshape2_args = ReshapeArgs(*reshape2.args, **reshape2.kwargs) # type: ignore[arg-type]
|
191
|
-
reshape2_input, reshape2_size = reshape2_args.input, reshape2_args.
|
186
|
+
reshape2_input, reshape2_size = reshape2_args.input, reshape2_args.shape
|
192
187
|
# (Bx1xAxC) - `aten.reshape - (Bx(A*C))
|
193
188
|
reshape2_input_shape = list(extract_shape(reshape2_input))
|
194
189
|
assert len(reshape2_input_shape) == 4
|
@@ -239,20 +234,14 @@ class RemoveRedundantReshapePattern3(PassBase):
|
|
239
234
|
graph = graph_module.graph
|
240
235
|
modified = False
|
241
236
|
for reshape_1 in graph.nodes:
|
242
|
-
assert isinstance(reshape_1, torch.fx.Node), type(reshape_1)
|
243
237
|
# reshape_1
|
244
|
-
if not reshape_1.
|
245
|
-
continue
|
246
|
-
if not reshape_1.target in ops.aten.reshape:
|
238
|
+
if not is_target_node(reshape_1, ops.aten.reshape):
|
247
239
|
continue
|
248
240
|
reshape_1_args = ReshapeArgs(*reshape_1.args, **reshape_1.kwargs) # type: ignore[arg-type]
|
249
|
-
softmax
|
241
|
+
softmax = reshape_1_args.input
|
250
242
|
|
251
243
|
# softmax
|
252
|
-
|
253
|
-
if not softmax.op == "call_function":
|
254
|
-
continue
|
255
|
-
if not softmax.target in ops.aten.softmax:
|
244
|
+
if not is_target_node(softmax, ops.aten.softmax):
|
256
245
|
continue
|
257
246
|
if softmax.target == torch.ops.aten._softmax.default:
|
258
247
|
softmax_args = SoftmaxArgs(*softmax.args, **softmax.kwargs) # type: ignore[arg-type, assignment]
|
@@ -281,7 +270,7 @@ class RemoveRedundantReshapePattern3(PassBase):
|
|
281
270
|
if not reshape_2.target in ops.aten.reshape:
|
282
271
|
continue
|
283
272
|
reshape_2_args = ReshapeArgs(*reshape_2.args, **reshape_2.kwargs) # type: ignore[arg-type]
|
284
|
-
reshape_2_input
|
273
|
+
reshape_2_input = reshape_2_args.input
|
285
274
|
assert isinstance(reshape_2_input, torch.fx.Node), type(reshape_2_input)
|
286
275
|
# reshape_3
|
287
276
|
if not reshape_3.op == "call_function":
|
@@ -289,7 +278,7 @@ class RemoveRedundantReshapePattern3(PassBase):
|
|
289
278
|
if not reshape_3.target in ops.aten.reshape:
|
290
279
|
continue
|
291
280
|
reshape_3_args = ReshapeArgs(*reshape_3.args, **reshape_3.kwargs) # type: ignore[arg-type]
|
292
|
-
reshape_3_input
|
281
|
+
reshape_3_input = reshape_3_args.input
|
293
282
|
assert isinstance(reshape_3_input, torch.fx.Node), type(reshape_3_input)
|
294
283
|
|
295
284
|
# Check condition
|
@@ -354,12 +343,11 @@ class RemoveRedundantReshapePattern4(PassBase):
|
|
354
343
|
modified = False
|
355
344
|
for reshape1 in graph.nodes:
|
356
345
|
# reshape_1
|
357
|
-
if not reshape1.
|
358
|
-
continue
|
359
|
-
if not reshape1.target in ops.aten.reshape:
|
346
|
+
if not is_target_node(reshape1, ops.aten.reshape):
|
360
347
|
continue
|
348
|
+
|
361
349
|
reshape1_args = ReshapeArgs(*reshape1.args, **reshape1.kwargs) # type: ignore[arg-type]
|
362
|
-
reshape1_input, size = reshape1_args.input, reshape1_args.
|
350
|
+
reshape1_input, size = reshape1_args.input, reshape1_args.shape
|
363
351
|
assert isinstance(reshape1_input, torch.fx.Node), type(reshape1_input)
|
364
352
|
assert isinstance(size, list), type(size)
|
365
353
|
for s in size:
|
@@ -370,13 +358,11 @@ class RemoveRedundantReshapePattern4(PassBase):
|
|
370
358
|
|
371
359
|
# reshape_2
|
372
360
|
reshape2 = next(iter(reshape1.users))
|
373
|
-
if not reshape2.
|
361
|
+
if not is_target_node(reshape2, ops.aten.reshape):
|
374
362
|
continue
|
375
|
-
if not reshape2.target in ops.aten.reshape:
|
376
|
-
continue
|
377
|
-
reshape2_args = ReshapeArgs(*reshape2.args, **reshape2.kwargs) # type: ignore[arg-type]
|
378
363
|
|
379
|
-
|
364
|
+
reshape2_args = ReshapeArgs(*reshape2.args, **reshape2.kwargs) # type: ignore[arg-type]
|
365
|
+
reshape2_input, reshape2_size = reshape2_args.input, reshape2_args.shape
|
380
366
|
assert isinstance(reshape2_input, torch.fx.Node), type(reshape2_input)
|
381
367
|
assert isinstance(reshape2_size, list), type(reshape2_size)
|
382
368
|
for s in reshape2_size:
|
@@ -420,14 +406,11 @@ class RemoveRedundantReshapePattern5(PassBase):
|
|
420
406
|
modified = False
|
421
407
|
|
422
408
|
for node in graph.nodes:
|
423
|
-
if not node.
|
424
|
-
continue
|
425
|
-
|
426
|
-
if not node.target in ops.aten.reshape:
|
409
|
+
if not is_target_node(node, ops.aten.reshape):
|
427
410
|
continue
|
428
411
|
|
429
412
|
args = ReshapeArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
430
|
-
output_shape = args.
|
413
|
+
output_shape = args.shape
|
431
414
|
input_shape = list(extract_shape(args.input))
|
432
415
|
|
433
416
|
if output_shape != input_shape:
|
@@ -19,6 +19,7 @@ from tico.serialize.circle_mapping import extract_shape
|
|
19
19
|
from tico.utils import logging
|
20
20
|
from tico.utils.passes import PassBase, PassResult
|
21
21
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
22
|
+
from tico.utils.utils import is_target_node
|
22
23
|
from tico.utils.validate_args_kwargs import SliceArgs
|
23
24
|
|
24
25
|
|
@@ -38,10 +39,7 @@ class RemoveRedundantSlice(PassBase):
|
|
38
39
|
graph = graph_module.graph
|
39
40
|
modified = False
|
40
41
|
for node in graph.nodes:
|
41
|
-
if not node.
|
42
|
-
continue
|
43
|
-
|
44
|
-
if not node.target in ops.aten.slice:
|
42
|
+
if not is_target_node(node, ops.aten.slice):
|
45
43
|
continue
|
46
44
|
|
47
45
|
args = SliceArgs(*node.args, **node.kwargs)
|
@@ -22,6 +22,7 @@ from tico.serialize.circle_mapping import extract_torch_dtype
|
|
22
22
|
from tico.utils import logging
|
23
23
|
from tico.utils.passes import PassBase, PassResult
|
24
24
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
25
|
+
from tico.utils.utils import is_target_node
|
25
26
|
from tico.utils.validate_args_kwargs import ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs
|
26
27
|
|
27
28
|
|
@@ -41,10 +42,7 @@ class RemoveRedundantToCopy(PassBase):
|
|
41
42
|
graph = graph_module.graph
|
42
43
|
modified = False
|
43
44
|
for node in graph.nodes:
|
44
|
-
if not node.
|
45
|
-
continue
|
46
|
-
|
47
|
-
if not node.target in ops.aten.to_copy:
|
45
|
+
if not is_target_node(node, ops.aten.to_copy):
|
48
46
|
continue
|
49
47
|
|
50
48
|
args: Union[ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs]
|
@@ -34,6 +34,7 @@ from tico.utils import logging
|
|
34
34
|
from tico.utils.graph import add_placeholder, is_single_value_tensor
|
35
35
|
from tico.utils.passes import PassBase, PassResult
|
36
36
|
from tico.utils.trace_decorators import trace_const_diff_on_pass
|
37
|
+
from tico.utils.utils import is_target_node
|
37
38
|
from tico.utils.validate_args_kwargs import IndexSelectArgs
|
38
39
|
|
39
40
|
|
@@ -78,10 +79,7 @@ class SegmentIndexSelectConst(PassBase):
|
|
78
79
|
graph = graph_module.graph
|
79
80
|
modified = False
|
80
81
|
for node in graph.nodes:
|
81
|
-
if not node.
|
82
|
-
continue
|
83
|
-
|
84
|
-
if not node.target in ops.aten.index_select:
|
82
|
+
if not is_target_node(node, ops.aten.index_select):
|
85
83
|
continue
|
86
84
|
|
87
85
|
args = IndexSelectArgs(*node.args, **node.kwargs)
|
@@ -48,7 +48,7 @@ class ReshapeVisitor(NodeVisitor):
|
|
48
48
|
)
|
49
49
|
args = ReshapeArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
50
50
|
input = args.input
|
51
|
-
size = args.
|
51
|
+
size = args.shape
|
52
52
|
|
53
53
|
if isinstance(size, int):
|
54
54
|
raise NotYetSupportedError("scalar size conversion is not supported yet.")
|