tico 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,51 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import torch
|
16
|
+
from torch.export import ExportedProgram
|
17
|
+
|
18
|
+
from tico.utils.passes import PassBase, PassResult
|
19
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
20
|
+
from tico.utils.utils import is_target_node
|
21
|
+
|
22
|
+
|
23
|
+
# Targets of the runtime-assertion ops erased by RemoveRedundantAssertionNodes.
assert_node_targets = [torch.ops.aten._assert_tensor_metadata.default]
|
26
|
+
|
27
|
+
|
28
|
+
@trace_graph_diff_on_pass
class RemoveRedundantAssertionNodes(PassBase):
    """
    Remove redundant runtime-assertion nodes from the graph.

    Nodes whose target is in `assert_node_targets` are erased. Currently
    that list contains:
      - `aten._assert_tensor_metadata.default`

    Matching nodes are erased unconditionally. `Graph.erase_node` refuses to
    erase a node that still has users, so the pass assumes these assertion
    nodes are never consumed by other nodes.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        # torch.fx documents that mutating the graph (e.g. erasing the node
        # currently being visited) while iterating `graph.nodes` is safe.
        for node in graph.nodes:
            if is_target_node(node, assert_node_targets):
                graph.erase_node(node)
                modified = True

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
@@ -0,0 +1,66 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.passes import ops
|
23
|
+
from tico.serialize.circle_mapping import extract_shape
|
24
|
+
from tico.utils import logging
|
25
|
+
from tico.utils.passes import PassBase, PassResult
|
26
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
|
+
from tico.utils.utils import is_target_node
|
28
|
+
from tico.utils.validate_args_kwargs import ExpandArgs
|
29
|
+
|
30
|
+
|
31
|
+
@trace_graph_diff_on_pass
class RemoveRedundantExpand(PassBase):
    """
    This pass removes redundant `aten.expand` operators where shapes of input and output are same.

    An expand is treated as a no-op when its requested size has the same
    rank as the input and every entry either equals the corresponding input
    dimension or is -1. For `expand`, -1 means "do not change this
    dimension". The requested size may be given as a list or a tuple.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _is_noop_size(input_shape, size) -> bool:
        """Return True if expanding `input_shape` to `size` keeps the shape unchanged."""
        input_dims = list(input_shape)
        target_dims = list(size)
        if len(input_dims) != len(target_dims):
            return False
        return all(t == -1 or t == d for d, t in zip(input_dims, target_dims))

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not is_target_node(node, ops.aten.expand):
                continue

            args = ExpandArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
            input_node, size = args.input, args.size

            input_shape = extract_shape(input_node)
            if not self._is_noop_size(input_shape, size):
                continue

            # The expand output is identical to its input, so users can read
            # the input directly; the dead expand is dropped by DCE below.
            node.replace_all_uses_with(input_node, propagate_meta=False)

            modified = True
            logger.debug(f"{node.name} is replaced with {input_node.name}")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
@@ -0,0 +1,122 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.passes import ops
|
23
|
+
from tico.serialize.circle_mapping import extract_shape
|
24
|
+
from tico.utils import logging
|
25
|
+
from tico.utils.graph import create_node
|
26
|
+
from tico.utils.passes import PassBase, PassResult
|
27
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
|
+
from tico.utils.utils import is_target_node
|
29
|
+
from tico.utils.validate_args_kwargs import PermuteArgs
|
30
|
+
|
31
|
+
|
32
|
+
def _compose_permutation(dims1: list[int], dims2: list[int]):
|
33
|
+
"""
|
34
|
+
Compose two permutation vectors.
|
35
|
+
|
36
|
+
Given y = x.permute(dims1) and z = y.permute(dims2),
|
37
|
+
the overall permutation p = dims2 ∘ dims1 is
|
38
|
+
|
39
|
+
p[i] = dims1[dims2[i]]
|
40
|
+
"""
|
41
|
+
assert len(dims1) == len(
|
42
|
+
dims2
|
43
|
+
), f"len(dims1): {len(dims1)}, len(dims2): {len(dims2)}"
|
44
|
+
return [dims1[i] for i in dims2]
|
45
|
+
|
46
|
+
|
47
|
+
def passes():
    """
    Build the list of passes that eliminate redundant `aten.permute` operators.

    NOTE Both shape and stride of input/output should be same.
    """
    permute_passes = [RemoveRedundantPermutePattern1()]
    return permute_passes
|
56
|
+
|
57
|
+
|
58
|
+
@trace_graph_diff_on_pass
class RemoveRedundantPermutePattern1(PassBase):
    """
    Collapse two back-to-back `aten.permute` nodes into a single permute, or
    remove both when their combined permutation is the identity.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        [BEFORE]
          (AxBxC) - aten.permute_1 - aten.permute_2 - (OUT_SHAPE)
        [AFTER]
          if OUT_SHAPE == (AxBxC):
            (AxBxC)
          else:
            (AxBxC) - aten.permute (fused dims) - (OUT_SHAPE)
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        g = gm.graph
        changed = False
        for outer in g.nodes:
            # Second permute of the pair; it must have a single consumer.
            if not is_target_node(outer, ops.aten.permute) or len(outer.users) != 1:
                continue
            outer_args = PermuteArgs(*outer.args, **outer.kwargs)  # type: ignore[arg-type]
            inner, outer_dims = outer_args.input, outer_args.dims

            # First permute of the pair; `outer` must be its only consumer.
            if not is_target_node(inner, ops.aten.permute) or len(inner.users) != 1:
                continue
            inner_args = PermuteArgs(*inner.args, **inner.kwargs)  # type: ignore[arg-type]
            source, inner_dims = inner_args.input, inner_args.dims

            combined = _compose_permutation(inner_dims, outer_dims)

            if combined != list(range(len(combined))):
                # Replace the pair with one permute using the composed dims.
                with g.inserting_after(outer):
                    replacement = create_node(
                        g,
                        torch.ops.aten.permute.default,
                        args=(source, combined),
                    )
                outer.replace_all_uses_with(replacement, propagate_meta=True)
                logger.debug(f"{inner.name} and {outer.name} are fused.")
            else:
                # The two permutes cancel out; users read the source directly.
                assert extract_shape(source) == extract_shape(outer)
                outer.replace_all_uses_with(source, propagate_meta=False)
                logger.debug(f"{inner.name} and {outer.name} are removed.")
            changed = True

        g.eliminate_dead_code()
        g.lint()
        gm.recompile()

        return PassResult(changed)
|
@@ -0,0 +1,436 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.passes import ops
|
23
|
+
from tico.serialize.circle_mapping import extract_shape
|
24
|
+
from tico.utils import logging
|
25
|
+
from tico.utils.graph import create_node
|
26
|
+
from tico.utils.passes import PassBase, PassResult
|
27
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
|
+
from tico.utils.utils import broadcastable, is_target_node, set_new_meta_val
|
29
|
+
from tico.utils.validate_args_kwargs import (
|
30
|
+
AddTensorArgs,
|
31
|
+
PermuteArgs,
|
32
|
+
ReshapeArgs,
|
33
|
+
SafeSoftmaxArgs,
|
34
|
+
SoftmaxArgs,
|
35
|
+
)
|
36
|
+
|
37
|
+
|
38
|
+
def passes():
    """
    Instantiate every pass that removes redundant `aten.reshape` operators.
    """
    pattern_classes = (
        RemoveRedundantReshapePattern1,
        RemoveRedundantReshapePattern2,
        RemoveRedundantReshapePattern3,
        RemoveRedundantReshapePattern4,
        RemoveRedundantReshapePattern5,
    )
    return [pattern_cls() for pattern_cls in pattern_classes]
|
49
|
+
|
50
|
+
|
51
|
+
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern1(PassBase):
    # `aten.mul` overloads (scalar and tensor variants) accepted in the
    # middle of the pattern.
    mul_ops: List[torch._ops.OpOverload] = ops.aten.mul_scalar + ops.aten.mul_tensor

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        [BEFORE]
          `(AxBxC) - aten.reshape` - (1xAxBxC) - `aten.permute` - (1xAxCxB) - `aten.mul` - (1xAxCxB) - `aten.reshape - (AxCxB)`
        [AFTER]
          `(AxBxC) - `aten.permute` - (AxCxB) - `aten.mul` - (AxCxB)`

        Any other tensor operand of `aten.mul` must have rank <= the rank of
        the final reshape output. A higher-rank operand (e.g. 1xAxCxB) would
        broadcast the rewritten product back to the higher rank. The users
        of the removed second reshape would then receive a tensor of the
        wrong shape.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for reshape1 in graph.nodes:
            ### first reshape
            if not is_target_node(reshape1, ops.aten.reshape):
                continue

            # Assumes that other nodes do not use ops in the pattern, for simplicity.
            if len(reshape1.users) != 1:
                continue
            reshape1_args = ReshapeArgs(*reshape1.args, **reshape1.kwargs)  # type: ignore[arg-type]
            reshape1_input = reshape1_args.input
            # `(AxBxC) - aten.reshape` - (1xAxBxC)
            if [1] + list(extract_shape(reshape1_input)) != list(
                extract_shape(reshape1)
            ):
                continue

            ### permute
            permute = next(iter(reshape1.users))
            if not is_target_node(permute, ops.aten.permute):
                continue
            if len(permute.users) != 1:
                continue
            permute_args = PermuteArgs(*permute.args, **permute.kwargs)  # type: ignore[arg-type]
            # (1xAxBxC) - `aten.permute` - (1xAxCxB)
            if permute_args.dims != [0, 1, 3, 2]:
                continue

            ### mul
            mul = next(iter(permute.users))
            if not is_target_node(mul, RemoveRedundantReshapePattern1.mul_ops):
                continue
            if len(mul.users) != 1:
                continue

            ### second reshape
            reshape2 = next(iter(mul.users))
            if not is_target_node(reshape2, ops.aten.reshape):
                continue
            if len(reshape2.users) != 1:
                continue
            reshape2_args = ReshapeArgs(*reshape2.args, **reshape2.kwargs)  # type: ignore[arg-type]
            reshape2_input = reshape2_args.input
            # (1xAxCxB) - `aten.reshape - (AxCxB)
            reshape2_shape = list(extract_shape(reshape2))
            if list(extract_shape(reshape2_input)) != [1] + reshape2_shape:
                continue

            ### mul's other operand must not restore the removed leading dim
            mul_operands = list(mul.args) + list(mul.kwargs.values())
            if any(
                isinstance(operand, torch.fx.Node)
                and operand is not permute
                and len(extract_shape(operand)) > len(reshape2_shape)
                for operand in mul_operands
            ):
                continue

            ### remove redundant reshapes
            # update permute (remove reshape1)
            permute.args = (reshape1_input, [0, 2, 1])
            set_new_meta_val(permute)
            set_new_meta_val(mul)
            # remove reshape2
            reshape2.replace_all_uses_with(mul, propagate_meta=False)

            modified = True
            logger.debug(f"{reshape1.name} and {reshape2.name} are removed.")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
135
|
+
|
136
|
+
|
137
|
+
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern2(PassBase):
    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        [BEFORE]
          `(AxBxC) - aten.reshape` - (1xAxBxC) - `aten.permute` - (Bx1xAxC) - `aten.reshape - (Bx(A*C))`
        [AFTER]
          `(AxBxC) - `aten.permute` - (BxAxC) - `aten.reshape` - (Bx(A*C))`
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        g = gm.graph
        changed = False
        for head in g.nodes:
            # Leading reshape (AxBxC) -> (1xAxBxC) with exactly one user.
            if not is_target_node(head, ops.aten.reshape):
                continue
            if len(head.users) != 1:
                continue
            head_src = ReshapeArgs(*head.args, **head.kwargs).input  # type: ignore[arg-type]
            expected_head_shape = [1, *extract_shape(head_src)]
            if list(extract_shape(head)) != expected_head_shape:
                continue

            # Permute (1xAxBxC) -> (Bx1xAxC) with exactly one user.
            perm = next(iter(head.users))
            if not is_target_node(perm, ops.aten.permute):
                continue
            if len(perm.users) != 1:
                continue
            perm_dims = PermuteArgs(*perm.args, **perm.kwargs).dims  # type: ignore[arg-type]
            if perm_dims != [2, 0, 1, 3]:
                continue

            # Trailing reshape (Bx1xAxC) -> (Bx(A*C)) with exactly one user.
            tail = next(iter(perm.users))
            if not is_target_node(tail, ops.aten.reshape):
                continue
            if len(tail.users) != 1:
                continue
            tail_args = ReshapeArgs(*tail.args, **tail.kwargs)  # type: ignore[arg-type]
            tail_src, tail_size = tail_args.input, tail_args.shape
            src_shape = list(extract_shape(tail_src))
            assert len(src_shape) == 4
            flattened = [src_shape[0], src_shape[2] * src_shape[3]]
            if list(extract_shape(tail)) != flattened:
                continue

            # Bypass the leading reshape: permute the original 3-D tensor.
            perm.args = (head_src, [1, 0, 2])
            set_new_meta_val(perm)
            head.replace_all_uses_with(perm, propagate_meta=False)
            # Re-point the trailing reshape at the rewritten permute.
            assert perm == tail_src
            tail.args = (perm, tail_size)

            changed = True
            logger.debug(f"{head.name} is removed.")

        g.eliminate_dead_code()
        g.lint()
        gm.recompile()

        return PassResult(changed)
|
214
|
+
|
215
|
+
|
216
|
+
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern3(PassBase):
    def __init__(self):
        super().__init__()

    @staticmethod
    def _strip_leading_ones(shape) -> list:
        """Return `shape` as a list with its leading size-1 dimensions removed."""
        dims = list(shape)
        start = 0
        while start < len(dims) and dims[start] == 1:
            start += 1
        return dims[start:]

    @staticmethod
    def _broadcast_shape(lhs, rhs) -> list:
        """
        Return the shape obtained by broadcasting `lhs` against `rhs`.

        Both shapes must already be broadcast-compatible (the caller checks
        this first); dimensions are aligned from the right.
        """
        lhs_dims, rhs_dims = list(lhs), list(rhs)
        rank = max(len(lhs_dims), len(rhs_dims))
        lhs_dims = [1] * (rank - len(lhs_dims)) + lhs_dims
        rhs_dims = [1] * (rank - len(rhs_dims)) + rhs_dims
        return [b if a == 1 else a for a, b in zip(lhs_dims, rhs_dims)]

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        [BEFORE]
          (AxBxC) - aten.reshape - (1xAxBxC) - aten.add - (1xAxBxC) - aten.softmax - (1xAxBxC) - aten.reshape - (AxBxC)
                   (reshape_2)                  (add)                 (softmax)                (reshape_1)
          (AxBxC) - aten.reshape - (1xAxBxC) /
                   (reshape_3)
        [AFTER]
          (AxBxC) - aten.add - (AxBxC) - aten.softmax - (AxBxC)
          (AxBxC) /    (add)               (softmax)

        The rewrite is applied only when both of these hold:
          - `reshape_2`, `reshape_3` and `reshape_1` merely add or drop
            leading size-1 dimensions;
          - broadcasting the two new `add` inputs yields exactly the shape
            produced by `reshape_1`.
        A reshape such as (B) -> (1xBx1) changes the axis along which an
        operand broadcasts. Dropping it would silently change the result.
        `add` nodes with a non-Node (scalar) operand are skipped.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for reshape_1 in graph.nodes:
            # reshape_1
            if not is_target_node(reshape_1, ops.aten.reshape):
                continue
            reshape_1_args = ReshapeArgs(*reshape_1.args, **reshape_1.kwargs)  # type: ignore[arg-type]
            softmax = reshape_1_args.input

            # softmax
            softmax_args = None
            if not is_target_node(softmax, ops.aten.softmax):
                continue
            if softmax.target == torch.ops.aten._softmax.default:
                softmax_args = SoftmaxArgs(*softmax.args, **softmax.kwargs)  # type: ignore[arg-type, assignment]
            elif softmax.target == torch.ops.aten._safe_softmax.default:
                softmax_args = SafeSoftmaxArgs(*softmax.args, **softmax.kwargs)  # type: ignore[arg-type, assignment]
            else:
                raise RuntimeError("Invalid input")
            assert softmax_args is not None
            add, softmax_dim = (
                softmax_args.input,
                softmax_args.dim,
            )
            softmax_shape = extract_shape(softmax)
            # TODO support other dimension
            if softmax_dim != -1 and softmax_dim != len(softmax_shape) - 1:
                continue

            # add
            if add.target not in ops.aten.add:
                continue
            add_args = AddTensorArgs(*add.args, **add.kwargs)  # type: ignore[arg-type]
            reshape_2, reshape_3 = add_args.input, add_args.other
            # `aten.add` may take a Python scalar operand (e.g. `x + 1.0`).
            # Such graphs cannot match this pattern; skip instead of asserting.
            if not isinstance(reshape_2, torch.fx.Node):
                continue
            if not isinstance(reshape_3, torch.fx.Node):
                continue

            # reshape_2
            if reshape_2.op != "call_function":
                continue
            if reshape_2.target not in ops.aten.reshape:
                continue
            reshape_2_args = ReshapeArgs(*reshape_2.args, **reshape_2.kwargs)  # type: ignore[arg-type]
            reshape_2_input = reshape_2_args.input
            assert isinstance(reshape_2_input, torch.fx.Node), type(reshape_2_input)
            # reshape_3
            if reshape_3.op != "call_function":
                continue
            if reshape_3.target not in ops.aten.reshape:
                continue
            reshape_3_args = ReshapeArgs(*reshape_3.args, **reshape_3.kwargs)  # type: ignore[arg-type]
            reshape_3_input = reshape_3_args.input
            assert isinstance(reshape_3_input, torch.fx.Node), type(reshape_3_input)

            # Check condition
            reshape_2_input_shape = extract_shape(reshape_2_input)
            reshape_3_input_shape = extract_shape(reshape_3_input)
            if not broadcastable(reshape_2_input_shape, reshape_3_input_shape):
                continue
            reshape_1_shape = extract_shape(reshape_1)
            if (
                reshape_2_input_shape != reshape_1_shape
                and reshape_3_input_shape != reshape_1_shape
            ):
                continue
            # Make sure the softmax axis length is unchanged.
            if softmax_shape[-1] != reshape_1_shape[-1]:
                continue
            # Each removed reshape must only add/drop leading size-1 dims;
            # otherwise the broadcasting (and thus the data) would change.
            strip = self._strip_leading_ones
            if strip(extract_shape(reshape_2)) != strip(reshape_2_input_shape):
                continue
            if strip(extract_shape(reshape_3)) != strip(reshape_3_input_shape):
                continue
            if strip(softmax_shape) != strip(reshape_1_shape):
                continue
            # The rewritten `add` must produce exactly reshape_1's shape.
            if self._broadcast_shape(
                reshape_2_input_shape, reshape_3_input_shape
            ) != list(reshape_1_shape):
                continue
            # Assume `aten.add` and `aten.softmax` have only one user.
            if len(add.users) != 1:
                continue
            if len(softmax.users) != 1:
                continue

            # Update add
            add.args = (reshape_2_input, reshape_3_input)
            set_new_meta_val(add)
            # Update softmax: the rank drops, so a positive last-dim index
            # is rewritten as -1.
            if softmax_dim == len(softmax_shape) - 1:
                softmax.update_arg(1, -1)  # (index, last_dim)
            set_new_meta_val(softmax)

            reshape_1.replace_all_uses_with(softmax, propagate_meta=False)
            modified = True
            logger.debug(
                f"{reshape_2.name}, {reshape_3.name} and {reshape_1.name} are removed."
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
328
|
+
|
329
|
+
|
330
|
+
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern4(PassBase):
    """Fuse two back-to-back `aten.reshape` ops into a single reshape."""

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Collapse a chain of two consecutive `aten.reshape` ops into one.

        NOTE: The graph below is only an example; the pattern is not limited
        to 3D tensors.

        [BEFORE]
          (AxBxC) - aten.reshape - (AxB'xC') - aten.reshape - (A'xB''xC')
        [AFTER]
          (AxBxC) - aten.reshape - (A'xB''xC')

        A pair is fused only when the first reshape has exactly one user and
        that user is itself a reshape. The new node reshapes the first
        reshape's input straight to the second reshape's target size, and it
        takes over the second reshape's users (and its meta). The original
        two nodes are then left dead and removed by dead-code elimination.

        Both reshapes' `shape` arguments are asserted to be lists of Python
        ints.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False

        for first_reshape in graph.nodes:
            if not is_target_node(first_reshape, ops.aten.reshape):
                continue

            first_args = ReshapeArgs(*first_reshape.args, **first_reshape.kwargs)  # type: ignore[arg-type]
            source = first_args.input
            first_size = first_args.shape
            assert isinstance(source, torch.fx.Node), type(source)
            assert isinstance(first_size, list), type(first_size)
            for dim in first_size:
                assert isinstance(dim, int), type(dim)

            # The first reshape may be bypassed only if nothing else reads it.
            if len(first_reshape.users) != 1:
                continue
            (second_reshape,) = first_reshape.users

            if not is_target_node(second_reshape, ops.aten.reshape):
                continue

            second_args = ReshapeArgs(*second_reshape.args, **second_reshape.kwargs)  # type: ignore[arg-type]
            second_input = second_args.input
            target_size = second_args.shape
            assert isinstance(second_input, torch.fx.Node), type(second_input)
            assert isinstance(target_size, list), type(target_size)
            for dim in target_size:
                assert isinstance(dim, int), type(dim)

            # The fused node is placed before the node currently being visited,
            # so this loop does not revisit it. Chains longer than two reshapes
            # may therefore need the pass to run again.
            with graph.inserting_before(first_reshape):
                merged = create_node(
                    graph,
                    first_reshape.target,
                    (source, target_size),
                )

            # `merged` produces the second reshape's output shape, so it
            # inherits that node's meta.
            second_reshape.replace_all_uses_with(merged, propagate_meta=True)

            modified = True
            logger.debug(
                f"{first_reshape.name} and {second_reshape.name} are fused to {merged.name}"
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
395
|
+
|
396
|
+
|
397
|
+
@trace_graph_diff_on_pass
class RemoveRedundantReshapePattern5(PassBase):
    """Drop `aten.reshape` ops whose target shape equals their input shape."""

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Remove a reshape that leaves its input's shape unchanged.

        [BEFORE]
          (AxBxC) - aten.reshape - (AxBxC)
        [AFTER]
          (AxBxC)

        The reshape's literal `shape` argument is compared with
        `list(extract_shape(input))`. On a match, every user of the reshape
        is rewired to the reshape's input, and the reshape is left for
        dead-code elimination.

        NOTE(review): because the literal argument is compared, a target size
        written with `-1` is presumably never treated as a no-op here
        (assuming `extract_shape` reports concrete sizes). Confirm this if
        such cases matter.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False

        for node in graph.nodes:
            if not is_target_node(node, ops.aten.reshape):
                continue

            reshape_args = ReshapeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
            source = reshape_args.input
            requested_shape = reshape_args.shape
            source_shape = list(extract_shape(source))

            if requested_shape == source_shape:
                # Identity reshape: consumers can read the input directly.
                # No new node is created, so no insertion point is required.
                node.replace_all_uses_with(source, propagate_meta=False)

                modified = True
                logger.debug(f"{node.name} is replaced with {source}")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|