tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,249 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
from typing import Optional
|
20
|
+
|
21
|
+
import torch
|
22
|
+
from torch.export import ExportedProgram
|
23
|
+
|
24
|
+
from tico.serialize.circle_mapping import extract_shape
|
25
|
+
from tico.utils import logging
|
26
|
+
from tico.utils.errors import NotYetSupportedError
|
27
|
+
from tico.utils.passes import PassBase, PassResult
|
28
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
29
|
+
from tico.utils.validate_args_kwargs import IndexArgs, UpsampleNearest2DVecArgs
|
30
|
+
|
31
|
+
|
32
|
+
@trace_graph_diff_on_pass
class LowerToResizeNearestNeighbor(PassBase):
    """
    This pass lowers `aten.index` and `aten.upsample_nearest2d.vec` to `circle_custom.resize_nearest_neighbor` when it is possible.

    Up to torch 2.7, `torch.nn.functional.interpolate` is converted to `aten.index` op.

    [EXAMPLE]
      class InterpolateDouble(torch.nn.Module):
        def __init__(self):
          super().__init__()

        def forward(self, x):
          return torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")

        def get_example_inputs(self):
          return (torch.randn(1, 2, 3, 4),)

    [EXPORTED GRAPH]
      [constants]
        _prop_tensor_constant0 = tensor([0, 0, 1, 1, 2, 2, 3, 3])
        _prop_tensor_constant1 = tensor([[0], [0], [1], [1], [2], [2]])

      [graph]
        %_prop_tensor_constant0 : [num_users=1] = placeholder[target=_prop_tensor_constant0]
        %_prop_tensor_constant1 : [num_users=1] = placeholder[target=_prop_tensor_constant1]
        %x : [num_users=1] = placeholder[target=x]
        %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%x,), kwargs = {dtype: torch.float32})
        %index : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%_to_copy, [None, None, %_prop_tensor_constant1, %_prop_tensor_constant0]), kwargs = {})
        %_to_copy_3 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%index,), kwargs = {dtype: torch.float32})
        return (_to_copy_3,)

    [BEFORE PASS]
      input - aten.index - output

    [AFTER PASS]
      input - aten.permute(NCHW_to_NHWC) - circle_custom.resize_nearest_neighbor - aten.permute(NHWC_to_NCHW) - output

    Since torch 2.8, `torch.nn.functional.interpolate` is converted to `aten.upsample_nearest2d.vec` op.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _insert_resize_nearest_neighbor(graph, node, input_tensor, size):
        """
        Replace all uses of `node` with
        permute(NCHW->NHWC) -> circle_custom.resize_nearest_neighbor -> permute(NHWC->NCHW).

        Args:
            graph: the graph that owns `node`.
            node: the node being lowered; new nodes are inserted right before it.
            input_tensor: the NCHW input node of `node`.
            size: the [H, W] output size passed to `resize_nearest_neighbor`.

        Returns:
            The inserted `resize_nearest_neighbor` node.
        """
        with graph.inserting_before(node):
            nchw_to_nhwc = graph.call_function(
                torch.ops.aten.permute.default, args=(input_tensor, [0, 2, 3, 1])
            )
            resize_nearest_neighbor = graph.call_function(
                torch.ops.circle_custom.resize_nearest_neighbor,
                args=(nchw_to_nhwc, size),
            )
            nhwc_to_nchw = graph.call_function(
                torch.ops.aten.permute.default,
                args=(resize_nearest_neighbor, [0, 3, 1, 2]),
            )
            # Meta of the new nodes is not set here; `propagate_meta=True`
            # copies `node`'s meta onto `nhwc_to_nchw`, which replaces it.
            node.replace_all_uses_with(nhwc_to_nchw, propagate_meta=True)
        return resize_nearest_neighbor

    def convert_index_to_resize_nearest_neighbor(
        self, exported_program, node
    ) -> Optional[torch.fx.Node]:
        """
        Try to lower an `aten.index.Tensor` node to `resize_nearest_neighbor`.

        The node is lowered only when it indexes a 4-D input with
        `[None, None, H_index, W_index]`. Both index tensors must be constants
        of the exported program that describe an integer nearest-neighbor
        upscale with the same factor for H and W (see the class docstring).

        Returns:
            The inserted `resize_nearest_neighbor` node, or None when the
            pattern does not match. The graph is left unchanged in that case.
        """
        graph = exported_program.graph_module.graph

        args = IndexArgs(*node.args, **node.kwargs)
        input_tensor = args.input
        indices = args.indices

        # Only support 4-D tensor: indices = [None, None, H index, W index]
        if len(indices) != 4:
            return None
        N, C, H, W = indices
        if N is not None or C is not None:
            return None
        if not isinstance(H, torch.fx.Node) or not isinstance(W, torch.fx.Node):
            return None

        # NOTE(review): constants are looked up by placeholder *name*; this
        # assumes the name matches the key used in `exported_program.constants`.
        constants_dict = exported_program.constants
        if (H.name not in constants_dict) or (W.name not in constants_dict):
            return None
        H_index, W_index = constants_dict[H.name], constants_dict[W.name]
        # Constants are not necessarily tensors, and a 0-d tensor has no length.
        if not isinstance(H_index, torch.Tensor) or not isinstance(
            W_index, torch.Tensor
        ):
            return None
        if H_index.dim() == 0 or W_index.dim() == 0:
            return None

        input_tensor_shape = extract_shape(input_tensor)
        # `index` may address only the leading dims of a higher-rank input,
        # but the NCHW <-> NHWC permutes below require exactly 4 dims.
        if len(input_tensor_shape) != 4:
            return None
        input_tensor_H, input_tensor_W = input_tensor_shape[2], input_tensor_shape[3]
        if input_tensor_H == 0 or input_tensor_W == 0:
            return None

        num_H_index, num_W_index = H_index.size(0), W_index.size(0)
        if num_H_index % input_tensor_H != 0:
            return None
        scale_factor = num_H_index // input_tensor_H
        # H and W should be resized with the same (positive integer) ratio.
        if scale_factor == 0 or num_W_index != scale_factor * input_tensor_W:
            return None

        # Expected H index: [[0], [0], [1], [1], ...] (cf. `_prop_tensor_constant1`).
        expected_H_index = [
            [i] for i in range(input_tensor_H) for _ in range(scale_factor)
        ]
        # Expected W index: [0, 0, 1, 1, ...] (cf. `_prop_tensor_constant0`).
        expected_W_index = [
            i for i in range(input_tensor_W) for _ in range(scale_factor)
        ]
        # Compare as nested lists so the shapes must match exactly; an
        # element-wise `torch.eq` would silently broadcast mismatched shapes.
        if (
            H_index.tolist() != expected_H_index
            or W_index.tolist() != expected_W_index
        ):
            return None

        output_H = input_tensor_H * scale_factor
        output_W = input_tensor_W * scale_factor
        expected_shape = [
            input_tensor_shape[0],
            input_tensor_shape[1],
            output_H,
            output_W,
        ]
        assert expected_shape == list(extract_shape(node))

        return self._insert_resize_nearest_neighbor(
            graph, node, input_tensor, [output_H, output_W]
        )

    def convert_upsample_nearest2d_to_resize_nearest_neighbor(
        self, exported_program, node
    ) -> Optional[torch.fx.Node]:
        """
        Lower an `aten.upsample_nearest2d.vec` node to `resize_nearest_neighbor`.

        Only the `scale_factors` form is supported. Each scaled spatial size
        (input size * scale factor) must be an integer.

        Returns:
            The inserted `resize_nearest_neighbor` node.

        Raises:
            NotYetSupportedError: if `output_size` is given, `scale_factors` is
                None, or a scaled spatial size is not an integer.
        """
        graph = exported_program.graph_module.graph

        args = UpsampleNearest2DVecArgs(*node.args, **node.kwargs)
        input_tensor = args.input
        output_size = args.output_size
        scale_factors = args.scale_factors

        input_tensor_shape = extract_shape(input_tensor)
        input_tensor_H, input_tensor_W = input_tensor_shape[2], input_tensor_shape[3]

        # TODO Support output_size case. Currently only scale_factors case is supported.
        if output_size is not None:
            raise NotYetSupportedError("output_size is not supported yet")

        if scale_factors is None:
            raise NotYetSupportedError("scale_factors is None")

        assert (
            isinstance(scale_factors[0], float)
            and isinstance(scale_factors[1], float)
            and scale_factors[0] > 0
            and scale_factors[1] > 0
        )

        def close_enough(x, y, epsilon=1e-10):
            return abs(x - y) < epsilon

        # `int()` truncates, so the product must already be (numerically) integral.
        expected_H = int(input_tensor_H * scale_factors[0])
        if not close_enough(expected_H, input_tensor_H * scale_factors[0]):
            raise NotYetSupportedError(
                f"Cannot support input_tensor_H ({input_tensor_H}) with scaling factor ({scale_factors[0]})"
            )

        expected_W = int(input_tensor_W * scale_factors[1])
        if not close_enough(expected_W, input_tensor_W * scale_factors[1]):
            raise NotYetSupportedError(
                f"Cannot support input_tensor_W ({input_tensor_W}) with scaling factor ({scale_factors[1]})"
            )

        return self._insert_resize_nearest_neighbor(
            graph, node, input_tensor, [expected_H, expected_W]
        )

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        modified = False
        graph_module = exported_program.graph_module
        graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "call_function":
                continue

            if node.target == torch.ops.aten.index.Tensor:
                resize_nearest_neighbor = self.convert_index_to_resize_nearest_neighbor(
                    exported_program, node
                )
            elif node.target == torch.ops.aten.upsample_nearest2d.vec:
                resize_nearest_neighbor = (
                    self.convert_upsample_nearest2d_to_resize_nearest_neighbor(
                        exported_program, node
                    )
                )
            else:
                continue

            if resize_nearest_neighbor is not None:
                modified = True
                logger.debug(
                    f"{node.name} is replaced with {resize_nearest_neighbor.name} operator"
                )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
@@ -0,0 +1,112 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.passes import ops
|
23
|
+
|
24
|
+
from tico.serialize.circle_graph import extract_shape
|
25
|
+
from tico.utils import logging
|
26
|
+
from tico.utils.passes import PassBase, PassResult
|
27
|
+
from tico.utils.trace_decorators import trace_const_diff_on_pass
|
28
|
+
from tico.utils.validate_args_kwargs import SelectCopyIntArgs
|
29
|
+
|
30
|
+
|
31
|
+
@trace_const_diff_on_pass
class LowerToSlice(PassBase):
    """
    This pass lowers `aten.select.int` / `aten.select_copy.int` to `aten.slice.Tensor`
    followed by `aten.reshape.default`.

    Only ops whose index is a constant integer in the node's args are supported.
    For the ops below, the index in the node's args is not a constant, so they
    cannot be converted yet:
    - torch.ops.aten.index_select.default
    - torch.ops.aten.embedding.default
    - torch.ops.aten.index.Tensor

    [before]
        input (tensor, dim, *index)
           |
         select
           |
         output

    [after]

        input (tensor, dim, *index)
           |
         slice (input=tensor, dim=dim, start=index, end=index+1, step=1)
           |
         reshape (input=slice_copy, size=select_shape)
           |
         output

    A negative `dim` or `index` is normalized to its non-negative equivalent
    before building the slice.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if node.op != "call_function":
                continue

            if node.target not in ops.aten.select:
                continue

            args = SelectCopyIntArgs(*node.args, **node.kwargs)
            input_tensor = args.input
            dim = args.dim
            index = args.index

            input_shape = extract_shape(input_tensor)
            if dim < 0:
                dim = dim % len(input_shape)
            # `select` accepts an index counted from the end (e.g. -1), but a
            # slice with start=-1, end=0 would be empty, so normalize it first.
            if index < 0:
                index = index + input_shape[dim]

            start = index
            end = index + 1
            step = 1
            slice_copy_args = (input_tensor, dim, start, end, step)

            with graph.inserting_after(node):
                # slice
                slice_node = graph.call_function(
                    torch.ops.aten.slice.Tensor, args=slice_copy_args
                )
            node_shape = extract_shape(node)
            with graph.inserting_after(slice_node):
                # reshape to select's output shape (which has no `dim` axis)
                reshape_args = (slice_node, list(node_shape))
                reshape_node = graph.call_function(
                    torch.ops.aten.reshape.default, args=reshape_args
                )
            node.replace_all_uses_with(reshape_node, propagate_meta=False)

            modified = True
            logger.debug(
                f"{node.name} is replaced with {slice_node.name} and {reshape_node.name} operators"
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from torch.export import ExportedProgram
|
16
|
+
|
17
|
+
from tico.passes import ops
|
18
|
+
from tico.utils import logging
|
19
|
+
from tico.utils.passes import PassBase, PassResult
|
20
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
21
|
+
from tico.utils.validate_args_kwargs import CatArgs
|
22
|
+
|
23
|
+
|
24
|
+
@trace_graph_diff_on_pass
class MergeConsecutiveCat(PassBase):
    """
    This pass merges consecutive `aten.cat` operators when they can be merged into a single operator.

    An input of a `cat` that is itself a `cat` along the same `dim` is replaced
    by that inner `cat`'s inputs, e.g.

        cat([cat([a, b], dim), c], dim)  ->  cat([a, b, c], dim)

    The inner `cat` node itself is not modified. It is removed by dead-code
    elimination only if nothing else uses it.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for cat in graph.nodes:
            if cat.op != "call_function":
                continue

            if cat.target not in ops.aten.cat:
                continue

            args = CatArgs(*cat.args, **cat.kwargs)  # type: ignore[arg-type]
            inputs = args.tensors
            dim = args.dim

            new_inputs = []
            # Tracked explicitly: merging a single-input inner cat does not
            # change the number of inputs, so a length check would miss it.
            merged = False
            for prev_cat in inputs:
                if prev_cat.op == "call_function" and prev_cat.target in ops.aten.cat:
                    prev_args = CatArgs(*prev_cat.args, **prev_cat.kwargs)  # type: ignore[arg-type]
                    if prev_args.dim == dim:
                        new_inputs.extend(prev_args.tensors)
                        merged = True
                        continue
                new_inputs.append(prev_cat)

            if merged:
                # Both operands are now positional; clear kwargs so that a
                # `dim` originally passed as a keyword is not supplied twice.
                cat.args = (new_inputs, dim)
                cat.kwargs = {}

                modified = True
                logger.debug(
                    f"Consecutive cat nodes before {cat.name} are merged into {cat.name}"
                )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
tico/passes/ops.py
ADDED
@@ -0,0 +1,75 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import torch
|
16
|
+
|
17
|
+
|
18
|
+
"""
|
19
|
+
This module contains Op lists used for finding the target Ops in passes.
|
20
|
+
The module is introduced to reduce duplicate codes.
|
21
|
+
It should be guaranteed that Ops in the same list have the same input/output signature.
|
22
|
+
"""
|
23
|
+
|
24
|
+
|
25
|
+
class AtenOps:
    """
    Named groups of `torch.ops.aten` operators, used by passes to match node targets.

    Each attribute is a list of operator overloads that belong to one op family
    (e.g. `select` holds both `select.int` and `select_copy.int`).
    """

    def __init__(self):
        aten_ops = torch.ops.aten
        # Attributes are listed in alphabetical order.
        self.add = [aten_ops.add.Tensor]
        self.alias = [aten_ops.alias.default, aten_ops.alias_copy.default]
        self.cat = [aten_ops.cat.default]
        self.clamp = [aten_ops.clamp.default, aten_ops.clamp.Tensor]
        self.clone = [aten_ops.clone.default]
        self.conv1d = [aten_ops.conv1d.default, aten_ops.conv1d.padding]
        self.conv2d = [aten_ops.conv2d.default, aten_ops.conv2d.padding]
        self.detach = [aten_ops.detach_.default, aten_ops.detach.default]
        self.expand = [aten_ops.expand.default, aten_ops.expand_copy.default]
        self.index_select = [aten_ops.index_select.default]
        self.mean = [aten_ops.mean.dim]
        self.mul_scalar = [aten_ops.mul.Scalar]
        self.mul_tensor = [aten_ops.mul.Tensor]
        self.permute = [aten_ops.permute.default]
        self.reshape = [aten_ops.reshape.default]
        self.select = [aten_ops.select_copy.int, aten_ops.select.int]
        self.slice = [aten_ops.slice.Tensor, aten_ops.slice_copy.Tensor]
        self.softmax = [aten_ops._softmax.default]
        self.squeeze = [aten_ops.squeeze.dims, aten_ops.squeeze_copy.dims]
        self.to_copy = [
            aten_ops._to_copy.default,
            aten_ops.to.dtype,
            aten_ops.to.dtype_layout,
        ]
        self.unsqueeze = [
            aten_ops.unsqueeze.default,
            aten_ops.unsqueeze_copy.default,
        ]
        # Note: the first entry is the `view` overload packet itself, not an overload.
        self.view = [
            aten_ops.view,
            aten_ops.view.default,
            aten_ops.view_copy.default,
        ]


# Shared instance imported by passes as `ops.aten`.
aten = AtenOps()
|
@@ -0,0 +1,85 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.passes import ops
|
23
|
+
from tico.utils import logging
|
24
|
+
from tico.utils.passes import PassBase, PassResult
|
25
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
26
|
+
|
27
|
+
|
28
|
+
@trace_graph_diff_on_pass
class RemoveNop(PassBase):
    """
    Remove no-op nodes by forwarding their single input to all of their users.

    A `call_function` node is treated as a no-op when its target is listed in
    `target_ops` (view_of, alias, clone, detach and lift_fresh_copy variants).
    A `clone` that carries an explicit `memory_format` kwarg is only removed
    for `torch.preserve_format` and `torch.contiguous_format`.
    """

    target_ops = (
        [torch.ops.prims.view_of.default]
        + ops.aten.alias
        + ops.aten.clone
        + ops.aten.detach
        + [torch.ops.aten.lift_fresh_copy.default]
    )

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        g = gm.graph
        changed = False
        for nop in g.nodes:
            if nop.op != "call_function" or nop.target not in RemoveNop.target_ops:
                continue

            # TODO Consider memory format
            if nop.target in ops.aten.clone and "memory_format" in nop.kwargs:
                removable_formats = (
                    torch.preserve_format,
                    # Converting a non-contiguous layout to contiguous only
                    # updates the tensor's strides, which is not visible in
                    # circle, so this clone can be dropped safely.
                    torch.contiguous_format,
                )
                if nop.kwargs["memory_format"] not in removable_formats:
                    continue

            assert len(nop.args) == 1
            source = nop.args[0]
            assert isinstance(source, torch.fx.Node)

            # Rewiring users does not create nodes, so no insertion point is needed.
            nop.replace_all_uses_with(source, propagate_meta=False)

            changed = True
            logger.debug(f"{nop.name} is replaced with {source}")

        g.eliminate_dead_code()
        g.lint()
        gm.recompile()

        return PassResult(changed)
|
@@ -0,0 +1,50 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import torch
|
16
|
+
from torch.export import ExportedProgram
|
17
|
+
|
18
|
+
from tico.utils.passes import PassBase, PassResult
|
19
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
20
|
+
|
21
|
+
|
22
|
+
# Targets of assertion-only nodes that this pass erases from the graph.
assert_node_targets = [
    torch.ops.aten._assert_tensor_metadata.default,
]


@trace_graph_diff_on_pass
class RemoveRedundantAssertionNodes(PassBase):
    """
    This removes redundant assertion nodes, i.e. `call_function` nodes whose
    target is listed in `assert_node_targets`:
    - `aten._assert_tensor_metadata.default`
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        # Iterate over a snapshot so erasing nodes cannot disturb the traversal.
        for node in list(graph.nodes):
            if node.op == "call_function" and node.target in assert_node_targets:
                graph.erase_node(node)
                modified = True

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|