tico 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,235 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Optional, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.serialize.circle_mapping import extract_shape
|
23
|
+
from tico.utils import logging
|
24
|
+
from tico.utils.errors import NotYetSupportedError
|
25
|
+
from tico.utils.graph import create_node
|
26
|
+
from tico.utils.passes import PassBase, PassResult
|
27
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
|
+
from tico.utils.utils import is_target_node
|
29
|
+
from tico.utils.validate_args_kwargs import IndexArgs, UpsampleNearest2DVecArgs
|
30
|
+
|
31
|
+
|
32
|
+
@trace_graph_diff_on_pass
class LowerToResizeNearestNeighbor(PassBase):
    """
    This pass lowers `aten.index` and `aten.upsample_nearest2d.vec` to
    `circle_custom.resize_nearest_neighbor` when it is possible.

    Until torch 2.7, `torch.nn.functional.interpolate` is converted to `aten.index` op.
    [BEFORE PASS]
      input - aten.index - output

    [AFTER PASS]
      input - aten.permute(NCHW_to_NHWC) - circle_custom.resize_nearest_neighbor - aten.permute(NHWC_to_NCHW) - output

    Since torch 2.8, `torch.nn.functional.interpolate` is converted to `aten.upsample_nearest2d.vec` op.
    [BEFORE PASS]
      input - aten.upsample_nearest2d.vec - output

    [AFTER PASS]
      input - aten.permute(NCHW_to_NHWC) - circle_custom.resize_nearest_neighbor - aten.permute(NHWC_to_NCHW) - output
    """

    def __init__(self):
        super().__init__()

    def convert_index_to_resize_nearest_neighbor(
        self, exported_program, node
    ) -> Optional[torch.fx.Node]:
        """
        Try to lower an `aten.index.Tensor` node to resize_nearest_neighbor.

        The node is lowered only when:
          - the input is 4-D and the indices are `[None, None, H_idx, W_idx]`,
          - H_idx and W_idx are lifted constants of the exported program, and
          - with `s` the common integer scale factor, H_idx equals
            `[[0]]*s + [[1]]*s + ... + [[H-1]]*s` and W_idx equals
            `[0]*s + [1]*s + ... + [W-1]*s` (compared element-wise with
            `torch.eq`, i.e. with broadcasting).

        Returns:
            The created resize node, or None when the pattern does not match
            (the graph is left untouched in that case).
        """
        graph = exported_program.graph_module.graph

        args = IndexArgs(*node.args, **node.kwargs)
        input_tensor = args.input
        indices = args.indices

        # Only the `[None, None, H index, W index]` form is supported.
        if len(indices) != 4:
            return None
        N, C, H, W = indices
        if N is not None or C is not None:
            return None
        if not isinstance(H, torch.fx.Node) or not isinstance(W, torch.fx.Node):
            return None

        constants_dict = exported_program.constants
        if H.name not in constants_dict or W.name not in constants_dict:
            return None
        H_index, W_index = constants_dict[H.name], constants_dict[W.name]
        if not isinstance(H_index, torch.Tensor) or not isinstance(
            W_index, torch.Tensor
        ):
            return None
        # `.size(0)` below needs at least one dimension.
        if H_index.dim() == 0 or W_index.dim() == 0:
            return None

        input_tensor_shape = extract_shape(input_tensor)
        # `aten.index` may index only the leading dims of a higher-rank tensor;
        # the NCHW <-> NHWC permutes created below require exactly 4 dims.
        if len(input_tensor_shape) != 4:
            return None
        input_tensor_H, input_tensor_W = input_tensor_shape[2], input_tensor_shape[3]

        output_H, output_W = H_index.size(0), W_index.size(0)
        if output_H % input_tensor_H != 0 or output_W % input_tensor_W != 0:
            return None
        scale_factor = output_H // input_tensor_H
        # H and W should be resized with the same (non-zero) integer ratio.
        if scale_factor < 1 or output_W // input_tensor_W != scale_factor:
            return None

        # Nearest-neighbor upsampling by `scale_factor` repeats every source
        # row/column index `scale_factor` times. The H index is built as a
        # column ([[i], ...]) and the W index as a flat list ([i, ...]).
        expected_H_index = [
            [i] for i in range(input_tensor_H) for _ in range(scale_factor)
        ]
        expected_W_index = [
            i for i in range(input_tensor_W) for _ in range(scale_factor)
        ]
        if not torch.all(
            torch.eq(H_index, torch.tensor(expected_H_index))
        ) or not torch.all(torch.eq(W_index, torch.tensor(expected_W_index))):
            return None

        resized_H, resized_W = len(expected_H_index), len(expected_W_index)
        expected_shape = [
            input_tensor_shape[0],
            input_tensor_shape[1],
            resized_H,
            resized_W,
        ]
        assert expected_shape == list(extract_shape(node))

        with graph.inserting_before(node):
            nchw_to_nhwc = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(input_tensor, [0, 2, 3, 1]),
                origin=input_tensor,
            )
            resize_nearest_neighbor = create_node(
                graph,
                torch.ops.circle_custom.resize_nearest_neighbor,
                args=(nchw_to_nhwc, [resized_H, resized_W]),
                origin=node,
            )
            nhwc_to_nchw = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(resize_nearest_neighbor, [0, 3, 1, 2]),
            )
        node.replace_all_uses_with(nhwc_to_nchw, propagate_meta=True)

        return resize_nearest_neighbor

    def convert_upsample_nearest2d_to_resize_nearest_neighbor(
        self, exported_program, node
    ) -> Optional[torch.fx.Node]:
        """
        Lower an `aten.upsample_nearest2d.vec` node given by `scale_factors`.

        Returns:
            The created resize node.

        Raises:
            NotYetSupportedError: if `output_size` is given, `scale_factors` is
                None, or `input size * scale factor` is not (numerically) an
                integer for H or W.
        """
        graph = exported_program.graph_module.graph

        args = UpsampleNearest2DVecArgs(*node.args, **node.kwargs)
        input_tensor = args.input
        output_size = args.output_size
        scale_factors = args.scale_factors

        input_tensor_shape = extract_shape(input_tensor)
        input_tensor_H, input_tensor_W = input_tensor_shape[2], input_tensor_shape[3]

        # TODO Support output_size case. Currently only scale_factors case is supported.
        if output_size is not None:
            raise NotYetSupportedError("output_size is not supported yet")

        if scale_factors is None:
            raise NotYetSupportedError("scale_factors is None")

        assert (
            isinstance(scale_factors[0], float)
            and isinstance(scale_factors[1], float)
            and scale_factors[0] > 0
            and scale_factors[1] > 0
        )

        def close_enough(x, y, epsilon=1e-10):
            return abs(x - y) < epsilon

        # The scaled size must be an integer; `int()` truncates the (positive)
        # product and the check rejects any fractional remainder.
        expected_H = int(input_tensor_H * scale_factors[0])
        if not close_enough(expected_H, input_tensor_H * scale_factors[0]):
            raise NotYetSupportedError(
                f"Cannot support input_tensor_H ({input_tensor_H}) with scaling factor ({scale_factors[0]})"
            )

        expected_W = int(input_tensor_W * scale_factors[1])
        if not close_enough(expected_W, input_tensor_W * scale_factors[1]):
            raise NotYetSupportedError(
                f"Cannot support input_tensor_W ({input_tensor_W}) with scaling factor ({scale_factors[1]})"
            )

        with graph.inserting_before(node):
            nchw_to_nhwc = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(input_tensor, [0, 2, 3, 1]),
                origin=input_tensor,
            )
            resize_nearest_neighbor = create_node(
                graph,
                torch.ops.circle_custom.resize_nearest_neighbor,
                args=(nchw_to_nhwc, [expected_H, expected_W]),
                origin=node,
            )
            nhwc_to_nchw = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(resize_nearest_neighbor, [0, 3, 1, 2]),
            )
        node.replace_all_uses_with(nhwc_to_nchw, propagate_meta=True)
        return resize_nearest_neighbor

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Lower every matching index/upsample node; the originals are removed by DCE."""
        logger = logging.getLogger(__name__)

        modified = False
        graph_module = exported_program.graph_module
        graph = graph_module.graph
        for node in graph.nodes:
            if not is_target_node(
                node,
                [torch.ops.aten.index.Tensor, torch.ops.aten.upsample_nearest2d.vec],
            ):
                continue

            resize_nearest_neighbor = None
            if node.target == torch.ops.aten.index.Tensor:
                resize_nearest_neighbor = self.convert_index_to_resize_nearest_neighbor(
                    exported_program, node
                )
            elif node.target == torch.ops.aten.upsample_nearest2d.vec:
                resize_nearest_neighbor = (
                    self.convert_upsample_nearest2d_to_resize_nearest_neighbor(
                        exported_program, node
                    )
                )

            if resize_nearest_neighbor is not None:
                modified = True
                logger.debug(
                    f"{node.name} is replaced with {resize_nearest_neighbor.name} operator"
                )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
@@ -0,0 +1,230 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch._export.utils import (
|
21
|
+
get_buffer,
|
22
|
+
get_lifted_tensor_constant,
|
23
|
+
get_param,
|
24
|
+
is_buffer,
|
25
|
+
is_lifted_tensor_constant,
|
26
|
+
is_param,
|
27
|
+
)
|
28
|
+
from torch.export import ExportedProgram
|
29
|
+
|
30
|
+
from tico.passes import ops
|
31
|
+
from tico.serialize.circle_graph import extract_shape
|
32
|
+
from tico.utils import logging
|
33
|
+
from tico.utils.graph import create_node, is_single_value_tensor
|
34
|
+
from tico.utils.passes import PassBase, PassResult
|
35
|
+
from tico.utils.trace_decorators import trace_const_diff_on_pass
|
36
|
+
from tico.utils.utils import is_target_node
|
37
|
+
from tico.utils.validate_args_kwargs import IndexSelectArgs, SelectCopyIntArgs
|
38
|
+
|
39
|
+
|
40
|
+
def passes():
    """
    Build the list of passes that lower select-like ops to `aten.slice`.

    - `LowerSelectCopyToSlice` handles `aten.select.int` / `aten.select_copy.int`.
    - `LowerIndexSelectToSlice` handles `aten.index_select` whose index is a
      constant single-value tensor.

    TODO Support below with const indices
    - torch.ops.aten.embedding.default
    - torch.ops.aten.index.Tensor
    """
    lowering_pass_types = (LowerSelectCopyToSlice, LowerIndexSelectToSlice)
    return [pass_type() for pass_type in lowering_pass_types]
|
54
|
+
|
55
|
+
|
56
|
+
@trace_const_diff_on_pass
class LowerSelectCopyToSlice(PassBase):
    """
    Lower `aten.select.int` / `aten.select_copy.int` to `aten.slice` + `aten.reshape`.

    [before]
      input
        |
      select (tensor, dim, index)
        |
      output

    [after]
      input
        |
      slice (input=tensor, dim=dim, start=index, end=index+1, step=1)
        |
      reshape (input=slice, size=select_shape)
        |
      output

    A negative `dim` or `index` is first normalized against the input shape.
    Without that, `index == -1` would become `slice(start=-1, end=0)`, which
    selects nothing.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not is_target_node(node, ops.aten.select):
                continue

            args = SelectCopyIntArgs(*node.args, **node.kwargs)
            input_tensor = args.input
            dim = args.dim
            index = args.index

            input_shape = extract_shape(input_tensor)
            if dim < 0:
                dim = dim % len(input_shape)
            if index < 0:
                # Make the slice bounds non-negative; `end = index + 1` would be
                # 0 for index == -1 and yield an empty slice.
                index = index + input_shape[dim]

            start = index
            end = index + 1
            step = 1
            slice_copy_args = (input_tensor, dim, start, end, step)

            with graph.inserting_after(node):
                # slice
                slice_node = create_node(
                    graph,
                    torch.ops.aten.slice.Tensor,
                    args=slice_copy_args,
                    origin=node,
                )
            node_shape = extract_shape(node)
            with graph.inserting_after(slice_node):
                # reshape: drop the size-1 `dim` kept by slice, as select does.
                reshape_args = (slice_node, list(node_shape))
                reshape_node = create_node(
                    graph,
                    torch.ops.aten.reshape.default,
                    args=reshape_args,
                )
            node.replace_all_uses_with(reshape_node, propagate_meta=True)

            modified = True
            logger.debug(
                f"{node.name} is replaced with {slice_node.name} and {reshape_node.name} operators"
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
133
|
+
|
134
|
+
|
135
|
+
@trace_const_diff_on_pass
class LowerIndexSelectToSlice(PassBase):
    """
    Lower `aten.index_select` with a constant single-value index to
    `aten.slice` + `aten.reshape`.

    [before]
      input
        |
      index_select.default (tensor, dim, index)
        |
      output

    [after]
      input
        |
      slice (input=tensor, dim=dim, start=index, end=index+1, step=1)
        |
      reshape (input=slice, size=index_select_shape)
        |
      output

    An index given as a graph node is resolved through the exported program:
    lifted tensor constants are tried first, then parameters, then buffers.
    Nodes whose index is not a single-value tensor are skipped. Those are
    expected to be handled by a LowerIndexSelect pass.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        g = gm.graph
        changed = False

        # (predicate, getter) pairs, tried in this order to fetch an index value.
        constant_lookups = (
            (is_lifted_tensor_constant, get_lifted_tensor_constant),
            (is_param, get_param),
            (is_buffer, get_buffer),
        )

        for node in g.nodes:
            if not is_target_node(node, ops.aten.index_select):
                continue

            parsed = IndexSelectArgs(*node.args, **node.kwargs)
            src = parsed.input
            axis = parsed.dim
            idx = parsed.index

            rank = len(extract_shape(src))
            axis = axis % rank if axis < 0 else axis

            if isinstance(idx, torch.fx.Node):
                fetch = next(
                    (
                        getter
                        for predicate, getter in constant_lookups
                        if predicate(exported_program, idx)
                    ),
                    None,
                )
                if fetch is None:
                    continue
                idx = fetch(exported_program, idx)  # type: ignore[assignment]

            if not isinstance(idx, torch.Tensor) or not is_single_value_tensor(idx):
                continue

            position = idx.item()  # scalar tensor -> Python number

            with g.inserting_after(node):
                slice_node = create_node(
                    g,
                    torch.ops.aten.slice.Tensor,
                    args=(src, axis, position, position + 1, 1),
                    origin=node,
                )
            out_shape = extract_shape(node)
            with g.inserting_after(slice_node):
                reshape_node = create_node(
                    g,
                    torch.ops.aten.reshape.default,
                    args=(slice_node, list(out_shape)),
                )
            node.replace_all_uses_with(reshape_node, propagate_meta=True)

            changed = True
            logger.debug(
                f"{node.name} is replaced with {slice_node.name} and {reshape_node.name} operators"
            )

        g.eliminate_dead_code()
        g.lint()
        gm.recompile()

        return PassResult(changed)
|
@@ -0,0 +1,80 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from torch.export import ExportedProgram
|
16
|
+
|
17
|
+
from tico.passes import ops
|
18
|
+
from tico.utils import logging
|
19
|
+
from tico.utils.passes import PassBase, PassResult
|
20
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
21
|
+
from tico.utils.utils import is_target_node
|
22
|
+
from tico.utils.validate_args_kwargs import CatArgs
|
23
|
+
|
24
|
+
|
25
|
+
@trace_graph_diff_on_pass
class MergeConsecutiveCat(PassBase):
    """
    This pass merges consecutive `aten.cat` operators when they can be merged
    into a single operator.

    An input of a `cat` node that is itself a `cat` along the same `dim` is
    replaced, in place, by that inner node's inputs:

        cat([cat([a, b], d), c], d)  ->  cat([a, b, c], d)

    `dim` values are compared as written, so e.g. -1 and 3 are not treated as
    equal. The inner `cat` is only removed by dead-code elimination when it has
    no remaining users.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for cat in graph.nodes:
            if not is_target_node(cat, ops.aten.cat):
                continue

            args = CatArgs(*cat.args, **cat.kwargs)  # type: ignore[arg-type]
            inputs = args.tensors
            dim = args.dim

            new_inputs = []
            merged = False
            for prev_cat in inputs:
                if prev_cat.op != "call_function" or prev_cat.target not in ops.aten.cat:
                    new_inputs.append(prev_cat)
                    continue

                prev_args = CatArgs(*prev_cat.args, **prev_cat.kwargs)  # type: ignore[arg-type]
                if prev_args.dim != dim:
                    new_inputs.append(prev_cat)
                    continue

                # Splice the inner cat's inputs in place of the inner cat.
                new_inputs.extend(prev_args.tensors)
                merged = True

            # Track merges explicitly: merging a single-input inner cat keeps
            # the input count unchanged but is still a valid rewrite.
            if not merged:
                continue

            cat.args = (new_inputs, dim)
            # `dim` is now passed positionally; drop any keyword copy so the
            # node is not called with duplicate arguments.
            cat.kwargs = {}

            modified = True
            logger.debug(
                f"Consecutive cat nodes before {cat.name} are merged into {cat.name}"
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
tico/passes/ops.py
ADDED
@@ -0,0 +1,78 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import torch
|
16
|
+
|
17
|
+
|
18
|
+
"""
|
19
|
+
This module contains Op lists used for finding the target Ops in passes.
|
20
|
+
The module is introduced to reduce duplicate codes.
|
21
|
+
It should be guaranteed that Ops in the same list have the same input/output signature.
|
22
|
+
"""
|
23
|
+
|
24
|
+
|
25
|
+
class AtenOps:
|
26
|
+
def __init__(self):
|
27
|
+
# In alphabetical order
|
28
|
+
self.add = [torch.ops.aten.add.Tensor]
|
29
|
+
self.alias = [torch.ops.aten.alias.default, torch.ops.aten.alias_copy.default]
|
30
|
+
self.cat = [torch.ops.aten.cat.default]
|
31
|
+
self.clamp = [torch.ops.aten.clamp.default, torch.ops.aten.clamp.Tensor]
|
32
|
+
self.clone = [torch.ops.aten.clone.default]
|
33
|
+
self.conv2d = [
|
34
|
+
torch.ops.aten.conv2d.default,
|
35
|
+
torch.ops.aten.conv2d.padding,
|
36
|
+
]
|
37
|
+
self.conv1d = [
|
38
|
+
torch.ops.aten.conv1d.default,
|
39
|
+
torch.ops.aten.conv1d.padding,
|
40
|
+
]
|
41
|
+
self.detach = [
|
42
|
+
torch.ops.aten.detach_.default,
|
43
|
+
torch.ops.aten.detach.default,
|
44
|
+
]
|
45
|
+
self.expand = [
|
46
|
+
torch.ops.aten.expand.default,
|
47
|
+
torch.ops.aten.expand_copy.default,
|
48
|
+
]
|
49
|
+
self.index_select = [torch.ops.aten.index_select.default]
|
50
|
+
self.mean = [torch.ops.aten.mean.dim]
|
51
|
+
self.mul_scalar = [torch.ops.aten.mul.Scalar]
|
52
|
+
self.mul_tensor = [torch.ops.aten.mul.Tensor]
|
53
|
+
self.permute = [torch.ops.aten.permute.default]
|
54
|
+
self.reshape = [torch.ops.aten.reshape.default]
|
55
|
+
self.select = [torch.ops.aten.select_copy.int, torch.ops.aten.select.int]
|
56
|
+
self.slice = [torch.ops.aten.slice.Tensor, torch.ops.aten.slice_copy.Tensor]
|
57
|
+
self.softmax = [
|
58
|
+
torch.ops.aten._softmax.default,
|
59
|
+
torch.ops.aten._safe_softmax.default,
|
60
|
+
]
|
61
|
+
self.squeeze = [torch.ops.aten.squeeze.dims, torch.ops.aten.squeeze_copy.dims]
|
62
|
+
self.to_copy = [
|
63
|
+
torch.ops.aten._to_copy.default,
|
64
|
+
torch.ops.aten.to.dtype,
|
65
|
+
torch.ops.aten.to.dtype_layout,
|
66
|
+
]
|
67
|
+
self.unsqueeze = [
|
68
|
+
torch.ops.aten.unsqueeze.default,
|
69
|
+
torch.ops.aten.unsqueeze_copy.default,
|
70
|
+
]
|
71
|
+
self.view = [
|
72
|
+
torch.ops.aten.view,
|
73
|
+
torch.ops.aten.view.default,
|
74
|
+
torch.ops.aten.view_copy.default,
|
75
|
+
]
|
76
|
+
|
77
|
+
|
78
|
+
aten = AtenOps()
|
@@ -0,0 +1,84 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.passes import ops
|
23
|
+
from tico.utils import logging
|
24
|
+
from tico.utils.passes import PassBase, PassResult
|
25
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
26
|
+
from tico.utils.utils import is_target_node
|
27
|
+
|
28
|
+
|
29
|
+
@trace_graph_diff_on_pass
class RemoveNop(PassBase):
    """
    Remove no-op nodes by rewiring their users to the node's single input.

    Targets are ``prims.view_of``, the alias/clone/detach groups from
    ``ops.aten``, and ``aten.lift_fresh_copy``. A clone that carries an
    explicit ``memory_format`` kwarg is only removed when that format is
    ``torch.preserve_format`` or ``torch.contiguous_format``. Clones with any
    other format are left in place.
    """

    target_ops = (
        [torch.ops.prims.view_of.default]
        + ops.aten.alias
        + ops.aten.clone
        + ops.aten.detach
        + [torch.ops.aten.lift_fresh_copy.default]
    )

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph

        # Memory formats a clone may request and still be dropped.
        # Converting a non-contiguous layout to contiguous only updates the
        # tensor's strides, which is not visible on circle, so such a clone
        # can safely be ignored.
        ignorable_formats = (torch.preserve_format, torch.contiguous_format)

        changed = False
        for nop in graph.nodes:
            if not is_target_node(nop, RemoveNop.target_ops):
                continue

            # TODO Consider memory format
            if (
                nop.target in ops.aten.clone
                and "memory_format" in nop.kwargs
                and nop.kwargs["memory_format"] not in ignorable_formats
            ):
                continue

            # Every target op is expected to take exactly one positional input.
            assert len(nop.args) == 1
            producer = nop.args[0]
            assert isinstance(producer, torch.fx.Node)

            with graph.inserting_after(nop):
                nop.replace_all_uses_with(producer, propagate_meta=False)

            changed = True
            logger.debug(f"{nop.name} is replaced with {producer}")

        # Bypassed nodes are left without users; eliminate_dead_code can then
        # drop them before the graph is checked and code is regenerated.
        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(changed)
|