tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,70 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.passes import ops
|
23
|
+
from tico.serialize.circle_mapping import extract_shape
|
24
|
+
from tico.utils import logging
|
25
|
+
from tico.utils.passes import PassBase, PassResult
|
26
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
|
+
|
28
|
+
|
29
|
+
@trace_graph_diff_on_pass
class RemoveRedundantExpand(PassBase):
    """
    This pass removes redundant `aten.expand` operators where shapes of input and output are same.

    An expand is considered redundant when its requested size has the same rank
    as the input and every entry either equals the corresponding input dimension
    or is `-1` (which `expand` defines as "do not change this dimension").
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Bypass every no-op `aten.expand` node in the exported graph.

        Args:
            exported_program: Program whose graph module is rewritten in place.

        Returns:
            PassResult carrying True when at least one expand node was bypassed.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if node.op != "call_function":
                continue

            if node.target not in ops.aten.expand:
                continue

            assert len(node.args) == 2

            # `expand_input` rather than `input` so the builtin is not shadowed.
            expand_input, size = node.args
            assert isinstance(expand_input, torch.fx.Node), type(expand_input)
            assert isinstance(size, list), type(size)

            input_shape = list(extract_shape(expand_input))
            # A different rank means the expand prepends dimensions: not a no-op.
            if len(size) != len(input_shape):
                continue
            # Each requested entry must keep the input dimension, either by
            # repeating its value or by using the `-1` placeholder.
            if not all(s == d or s == -1 for s, d in zip(size, input_shape)):
                continue

            node.replace_all_uses_with(expand_input, propagate_meta=False)

            modified = True
            logger.debug(f"{node.name} is replaced with {expand_input.name}")

        # The bypassed expand nodes have no users left and are dropped here.
        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
@@ -0,0 +1,102 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.passes import ops
|
23
|
+
from tico.serialize.circle_mapping import extract_shape, extract_stride
|
24
|
+
from tico.utils import logging
|
25
|
+
from tico.utils.passes import PassBase, PassResult
|
26
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
|
+
|
28
|
+
|
29
|
+
def passes():
    """
    Build the list of passes that eliminate redundant `aten.permute` operators.

    NOTE A permute pair is only dropped when both shape and stride of its input/output agree.
    """
    permute_passes = [RemoveRedundantPermutePattern1()]
    return permute_passes
|
38
|
+
|
39
|
+
|
40
|
+
@trace_graph_diff_on_pass
class RemoveRedundantPermutePattern1(PassBase):
    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Bypass two chained `aten.permute` nodes that restore the original layout.

        [BEFORE]
            (AxBxC) - aten.permute - aten.permute - (AxBxC)
        [AFTER]
            (AxBxC)

        The pair is bypassed only when the tensor entering the first permute and
        the tensor leaving the second one report equal shape and equal stride,
        and each of the two permutes has exactly one user.
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        fx_graph = gm.graph
        changed = False
        for outer in fx_graph.nodes:
            # `outer` is the candidate for the second permute of the pair.
            is_permute_call = (
                outer.op == "call_function" and outer.target in ops.aten.permute
            )
            if not is_permute_call or len(outer.users) != 1:
                continue
            assert len(outer.args) == 2
            inner, outer_dims = outer.args
            assert isinstance(inner, torch.fx.Node), type(inner)
            assert isinstance(outer_dims, list), type(outer_dims)
            for dim in outer_dims:
                assert isinstance(dim, int), type(dim)

            # `inner` must be the first permute and feed nothing but `outer`.
            if inner.target not in ops.aten.permute or len(inner.users) != 1:
                continue
            assert len(inner.args) == 2
            source, inner_dims = inner.args
            assert isinstance(source, torch.fx.Node), type(source)
            assert isinstance(inner_dims, list), type(inner_dims)
            for dim in inner_dims:
                assert isinstance(dim, int), type(dim)

            # shape
            if extract_shape(source) != extract_shape(outer):
                continue
            # stride
            if extract_stride(source) != extract_stride(outer):
                continue

            outer.replace_all_uses_with(source, propagate_meta=False)

            changed = True
            logger.debug(f"{inner.name} and {outer.name} are removed.")

        fx_graph.eliminate_dead_code()
        fx_graph.lint()
        gm.recompile()

        return PassResult(changed)
|
@@ -0,0 +1,431 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.passes import ops
|
23
|
+
from tico.serialize.circle_mapping import extract_shape
|
24
|
+
from tico.utils import logging
|
25
|
+
from tico.utils.passes import PassBase, PassResult
|
26
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
|
+
from tico.utils.utils import set_new_meta_val
|
28
|
+
from tico.utils.validate_args_kwargs import ReshapeArgs
|
29
|
+
|
30
|
+
|
31
|
+
def passes():
    """
    Build the passes that remove redundant `aten.reshape` operators, in pattern order 1 to 5.
    """
    pattern_classes = (
        RemoveRedundantReshapePattern1,
        RemoveRedundantReshapePattern2,
        RemoveRedundantReshapePattern3,
        RemoveRedundantReshapePattern4,
        RemoveRedundantReshapePattern5,
    )
    return [pattern_cls() for pattern_cls in pattern_classes]
|
42
|
+
|
43
|
+
|
44
|
+
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern1(PassBase):
    # Targets accepted for the `aten.mul` step of the pattern (scalar and tensor variants).
    mul_ops: List[torch._ops.OpOverload] = ops.aten.mul_scalar + ops.aten.mul_tensor

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Drop the reshape pair that wraps a permute + mul in a leading unit dimension.

        [BEFORE]
            `(AxBxC) - aten.reshape` - (1xAxBxC) - `aten.permute` - (1xAxCxB) - `aten.mul` - (1xAxCxB) - `aten.reshape - (AxCxB)`
        [AFTER]
            `(AxBxC) - `aten.permute` - (AxCxB) - `aten.mul` - (AxCxB)`

        Every node of the chain must have exactly one user for the rewrite to fire.
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        fx_graph = gm.graph
        changed = False
        for head_reshape in fx_graph.nodes:
            if head_reshape.op != "call_function":
                continue

            ### leading reshape
            if head_reshape.target not in ops.aten.reshape:
                continue
            # For simplicity, skip the pattern when one of its ops is used by other nodes.
            if len(head_reshape.users) != 1:
                continue
            assert len(head_reshape.args) == 2, len(head_reshape.args)
            source, _ = head_reshape.args
            # `(AxBxC) - aten.reshape` - (1xAxBxC)
            with_unit_dim = [1] + list(extract_shape(source))
            if with_unit_dim != list(extract_shape(head_reshape)):
                continue

            ### permute
            permute = next(iter(head_reshape.users))
            if permute.target not in ops.aten.permute or len(permute.users) != 1:
                continue
            assert len(permute.args) == 2, len(permute.args)
            _, permute_dims = permute.args
            # (1xAxBxC) - `aten.permute` - (1xAxCxB)
            if permute_dims != [0, 1, 3, 2]:
                continue

            ### mul
            mul = next(iter(permute.users))
            if mul.target not in RemoveRedundantReshapePattern1.mul_ops:
                continue
            if len(mul.users) != 1:
                continue

            ### trailing reshape
            tail_reshape = next(iter(mul.users))
            if tail_reshape.target not in ops.aten.reshape:
                continue
            if len(tail_reshape.users) != 1:
                continue
            tail_input, _ = tail_reshape.args
            # (1xAxCxB) - `aten.reshape - (AxCxB)
            if list(extract_shape(tail_input)) != [1] + list(
                extract_shape(tail_reshape)
            ):
                continue

            ### rewrite
            # Feed the permute straight from the 3-D source (this detaches the leading
            # reshape), using the 3-D equivalent of the original axis swap.
            permute.args = (source, [0, 2, 1])
            set_new_meta_val(permute)
            set_new_meta_val(mul)
            # Consumers of the trailing reshape now read the mul result directly.
            tail_reshape.replace_all_uses_with(mul, propagate_meta=False)

            changed = True
            logger.debug(f"{head_reshape.name} and {tail_reshape.name} are removed.")

        fx_graph.eliminate_dead_code()
        fx_graph.lint()
        gm.recompile()

        return PassResult(changed)
|
129
|
+
|
130
|
+
|
131
|
+
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern2(PassBase):
    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Drop a reshape that only adds a leading unit dimension ahead of permute + reshape.

        [BEFORE]
            `(AxBxC) - aten.reshape` - (1xAxBxC) - `aten.permute` - (Bx1xAxC) - `aten.reshape - (Bx(A*C))`
        [AFTER]
            `(AxBxC) - `aten.permute` - (BxAxC) - `aten.reshape` - (Bx(A*C))`

        Every node of the chain must have exactly one user for the rewrite to fire.
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        fx_graph = gm.graph
        changed = False
        for head_reshape in fx_graph.nodes:
            if head_reshape.op != "call_function":
                continue

            ### leading reshape
            if head_reshape.target not in ops.aten.reshape:
                continue
            if len(head_reshape.users) != 1:
                continue
            assert len(head_reshape.args) == 2, len(head_reshape.args)
            source, _ = head_reshape.args
            # `(AxBxC) - aten.reshape` - (1xAxBxC)
            with_unit_dim = [1] + list(extract_shape(source))
            if with_unit_dim != list(extract_shape(head_reshape)):
                continue

            ### permute
            permute = next(iter(head_reshape.users))
            if permute.target not in ops.aten.permute or len(permute.users) != 1:
                continue
            assert len(permute.args) == 2, len(permute.args)
            _, permute_dims = permute.args
            # (1xAxBxC) - `aten.permute` - (Bx1xAxC)
            if permute_dims != [2, 0, 1, 3]:
                continue

            ### trailing reshape
            tail_reshape = next(iter(permute.users))
            if tail_reshape.target not in ops.aten.reshape:
                continue
            if len(tail_reshape.users) != 1:
                continue
            tail_input, tail_size = tail_reshape.args
            # (Bx1xAxC) - `aten.reshape - (Bx(A*C))
            tail_input_shape = list(extract_shape(tail_input))
            assert len(tail_input_shape) == 4
            flattened = [
                tail_input_shape[0],
                tail_input_shape[2] * tail_input_shape[3],
            ]
            if list(extract_shape(tail_reshape)) != flattened:
                continue

            ### rewrite
            # Feed the permute straight from the 3-D source with the matching 3-D axis order.
            permute.args = (source, [1, 0, 2])
            set_new_meta_val(permute)
            head_reshape.replace_all_uses_with(permute, propagate_meta=False)
            # Re-bind the trailing reshape to the updated permute with its original size.
            assert permute == tail_input
            tail_reshape.args = (permute, tail_size)

            changed = True
            logger.debug(f"{head_reshape.name} is removed.")

        fx_graph.eliminate_dead_code()
        fx_graph.lint()
        gm.recompile()

        return PassResult(changed)
|
210
|
+
|
211
|
+
|
212
|
+
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern3(PassBase):
    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Remove the reshapes surrounding an `add -> softmax` chain.

        [BEFORE]
            (AxBxC) - aten.reshape - (1xAxBxC) - aten.add - (1xAxBxC) - aten.softmax - (1xAxBxC) - aten.reshape - (AxBxC)
                      (reshape_2)      /          (add)                  (softmax)                  (reshape_1)
            (AxBxC) - aten.reshape - (1xAxBxC)
                      (reshape_3)
        [AFTER]
            (AxBxC) - aten.add - (AxBxC) - aten.softmax - (AxBxC)
            (AxBxC) /  (add)                (softmax)

        The rewrite is applied only when:
          - softmax runs over the last dimension,
          - both reshape inputs and the final reshape output have the same shape,
          - the last dimension has the same length before and after the reshapes
            (otherwise softmax would normalize over a different set of elements),
          - `add` and `softmax` each have exactly one user.

        Args:
            exported_program: Program whose graph module is rewritten in place.

        Returns:
            PassResult carrying True when at least one pattern was rewritten.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for reshape_1 in graph.nodes:
            assert isinstance(reshape_1, torch.fx.Node), type(reshape_1)
            # reshape_1
            if reshape_1.op != "call_function":
                continue
            if reshape_1.target not in ops.aten.reshape:
                continue
            assert len(reshape_1.args) == 2, len(reshape_1.args)
            softmax = reshape_1.args[0]

            # softmax
            assert isinstance(softmax, torch.fx.Node), type(softmax)
            if softmax.op != "call_function":
                continue
            if softmax.target not in ops.aten.softmax:
                continue
            assert len(softmax.args) == 3, len(softmax.args)
            add, softmax_dim, softmax_half_to_float = softmax.args
            assert isinstance(add, torch.fx.Node), type(add)
            assert isinstance(softmax_dim, int), type(softmax_dim)
            assert isinstance(softmax_half_to_float, bool), type(softmax_half_to_float)
            softmax_shape = extract_shape(softmax)
            # TODO support other dimension
            if softmax_dim != -1 and softmax_dim != len(softmax_shape) - 1:
                continue

            # add
            if add.target not in ops.aten.add:
                continue
            assert len(add.args) == 2, len(add.args)
            reshape_2, reshape_3 = add.args
            assert isinstance(reshape_2, torch.fx.Node), type(reshape_2)
            assert isinstance(reshape_3, torch.fx.Node), type(reshape_3)

            # reshape_2
            if reshape_2.op != "call_function":
                continue
            if reshape_2.target not in ops.aten.reshape:
                continue
            assert len(reshape_2.args) == 2, len(reshape_2.args)
            reshape_2_input = reshape_2.args[0]
            assert isinstance(reshape_2_input, torch.fx.Node), type(reshape_2_input)
            # reshape_3
            if reshape_3.op != "call_function":
                continue
            if reshape_3.target not in ops.aten.reshape:
                continue
            assert len(reshape_3.args) == 2, len(reshape_3.args)
            reshape_3_input = reshape_3.args[0]
            assert isinstance(reshape_3_input, torch.fx.Node), type(reshape_3_input)

            # Check condition
            reshape_2_input_shape = extract_shape(reshape_2_input)
            reshape_3_input_shape = extract_shape(reshape_3_input)
            if reshape_2_input_shape != reshape_3_input_shape:
                continue
            reshape_1_shape = extract_shape(reshape_1)
            if reshape_2_input_shape != reshape_1_shape:
                continue
            # Softmax is moved from the reshaped tensor onto the original one and
            # keeps normalizing over the last axis. That is only the same
            # computation when the last axis has the same length on both sides,
            # e.g. (A,B,C) -> (A,B*C) -> softmax(-1) must not be rewritten.
            reshaped_dims = list(softmax_shape)
            original_dims = list(reshape_2_input_shape)
            if not reshaped_dims or not original_dims:
                continue
            if reshaped_dims[-1] != original_dims[-1]:
                continue
            # Assume `aten.add` and `aten.softmax` have only one user.
            if len(add.users) != 1:
                continue
            if len(softmax.users) != 1:
                continue

            # Update add
            add.args = (reshape_2_input, reshape_3_input)
            set_new_meta_val(add)
            # Update softmax
            # A positive last-axis index has to follow the new (original) rank;
            # `-1` stays valid as it is.
            if softmax_dim == len(softmax_shape) - 1:
                updated_dim = len(original_dims) - 1
                softmax.args = (add, updated_dim, softmax_half_to_float)
            set_new_meta_val(softmax)

            reshape_1.replace_all_uses_with(softmax, propagate_meta=False)
            modified = True
            logger.debug(
                f"{reshape_2.name}, {reshape_3.name} and {reshape_1.name} are removed."
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
318
|
+
|
319
|
+
|
320
|
+
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern4(PassBase):
    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Fuse two consecutive `aten.reshape` ops into a single reshape.

        NOTE: The graph below is only an illustration; the pattern is not limited to 3D tensors.
        [BEFORE]
            (AxBxC) - aten.reshape - (AxB'xC') - aten.reshape - (A'xB''xC')
        [AFTER]
            (AxBxC) - aten.reshape - (A'xB''xC')

        The first reshape must have the second reshape as its only user.
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        fx_graph = gm.graph
        changed = False
        for first in fx_graph.nodes:
            # first reshape
            if first.op != "call_function" or first.target not in ops.aten.reshape:
                continue
            assert len(first.args) == 2, len(first.args)

            source, first_size = first.args
            assert isinstance(source, torch.fx.Node), type(source)
            assert isinstance(first_size, list), type(first_size)
            for s in first_size:
                assert isinstance(s, int), type(s)

            if len(first.users) != 1:
                continue

            # second reshape
            second = next(iter(first.users))
            if second.op != "call_function" or second.target not in ops.aten.reshape:
                continue
            assert len(second.args) == 2, len(second.args)

            second_input, second_size = second.args
            assert isinstance(second_input, torch.fx.Node), type(second_input)
            assert isinstance(second_size, list), type(second_size)
            for s in second_size:
                assert isinstance(s, int), type(s)

            # One reshape from the original source directly to the final size,
            # created ahead of the first reshape in the graph.
            with fx_graph.inserting_before(first):
                fused = fx_graph.call_function(first.target, (source, second_size))

            second.replace_all_uses_with(fused, propagate_meta=True)

            changed = True
            logger.debug(f"{first.name} and {second.name} are fused to {fused.name}")

        fx_graph.eliminate_dead_code()
        fx_graph.lint()
        gm.recompile()

        return PassResult(changed)
|
387
|
+
|
388
|
+
|
389
|
+
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern5(PassBase):
    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Bypass an `aten.reshape` whose requested size equals its input shape.

        [BEFORE]
            (AxBxC) - aten.reshape - (AxBxC)
        [AFTER]
            (AxBxC)
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        fx_graph = gm.graph
        changed = False

        for candidate in fx_graph.nodes:
            is_reshape_call = (
                candidate.op == "call_function"
                and candidate.target in ops.aten.reshape
            )
            if not is_reshape_call:
                continue

            args = ReshapeArgs(*candidate.args)
            requested_size = args.size
            current_shape = list(extract_shape(args.input))

            if requested_size != current_shape:
                continue

            # NOTE(review): no node is created inside this insertion context, so it
            # looks unnecessary — confirm before removing.
            with fx_graph.inserting_after(candidate):
                candidate.replace_all_uses_with(args.input, propagate_meta=False)

            changed = True
            logger.debug(f"{candidate.name} is replaced with {args.input}")

        fx_graph.eliminate_dead_code()
        fx_graph.lint()
        gm.recompile()

        return PassResult(changed)
|
@@ -0,0 +1,64 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from torch.export import ExportedProgram
|
16
|
+
|
17
|
+
from tico.passes import ops
|
18
|
+
from tico.serialize.circle_mapping import extract_shape
|
19
|
+
from tico.utils import logging
|
20
|
+
from tico.utils.passes import PassBase, PassResult
|
21
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
22
|
+
from tico.utils.validate_args_kwargs import SliceArgs
|
23
|
+
|
24
|
+
|
25
|
+
@trace_graph_diff_on_pass
class RemoveRedundantSlice(PassBase):
    """
    Pass that bypasses slice operators whose output shape equals their input shape.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Redirect the users of every shape-preserving slice node to the slice input.

        Returns a PassResult carrying True when at least one slice was bypassed.
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        fx_graph = gm.graph
        changed = False
        for candidate in fx_graph.nodes:
            is_slice_call = (
                candidate.op == "call_function" and candidate.target in ops.aten.slice
            )
            if not is_slice_call:
                continue

            args = SliceArgs(*candidate.args, **candidate.kwargs)

            # The slice keeps everything when its output shape matches the input shape.
            if extract_shape(args.input) != extract_shape(candidate):
                continue

            candidate.replace_all_uses_with(args.input, propagate_meta=False)

            changed = True
            logger.debug(f"{candidate.name} is replaced with {args.input.name}")

        fx_graph.eliminate_dead_code()
        fx_graph.lint()
        gm.recompile()

        return PassResult(changed)
|