tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,202 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.passes import ops
|
23
|
+
from tico.serialize.circle_mapping import extract_shape
|
24
|
+
from tico.utils import logging
|
25
|
+
from tico.utils.errors import InvalidArgumentError, NotYetSupportedError
|
26
|
+
from tico.utils.graph import add_placeholder
|
27
|
+
from tico.utils.passes import PassBase, PassResult
|
28
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
29
|
+
from tico.utils.validate_args_kwargs import Conv2DArgs
|
30
|
+
|
31
|
+
|
32
|
+
@trace_graph_diff_on_pass
class DecomposeGroupedConv2d(PassBase):
    """
    This pass decomposes a grouped Conv2d operator into multiple Conv2d operators whose groups=1.

    Grouped Conv2d denotes a Conv2d operator whose `groups` argument is equal to
    neither the number of input channels nor 1.

    [before]

        input     weight     bias
          |         |          |
          +---------+----------+
                    |
          Conv2d (groups != IN_CHANNEL && groups != 1)
                    |
                  output

    [after]

    The `slice` operators below cut the input tensor (along axis 1), the weight
    (along axis 0) and the bias (along axis 0) into `groups` equal parts. The
    numbered input, weight and bias denote the i-th slice of each tensor.
    If the original Conv2d has no bias, a zero bias placeholder is created first.

        input           weight           bias
          |               |                |
        slice           slice            slice
          |               |                |
        input_i        weight_i          bias_i        (i = 1 .. N, N = groups)
          |               |                |
          +---------------+----------------+
                          |
                 Conv2d_i (groups=1)
                          |
            Conv2d_1 ... Conv2d_N outputs
                          |
                   concat (axis 1)
                          |
                        output
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Rewrite every grouped `aten.conv2d` node of the exported program.

        Returns:
            PassResult(modified), where `modified` is True if at least one
            conv2d node was decomposed.

        Raises:
            NotYetSupportedError: a grouped conv2d whose input is not 4D.
            InvalidArgumentError: padding is neither a sequence of ints nor a str.
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        graph: torch.fx.Graph = gm.graph
        modified = False

        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if node.target not in ops.aten.conv2d:
                continue

            args = Conv2DArgs(*node.args)
            input_ = args.input
            weight = args.weight
            bias = args.bias
            stride = args.stride
            padding = args.padding
            dilation = args.dilation
            groups = args.groups

            # A plain convolution needs no decomposition. Skip it before the
            # rank check so such nodes never raise NotYetSupportedError here.
            if groups == 1:
                continue

            input_shape = extract_shape(input_)
            if len(input_shape) != 4:
                raise NotYetSupportedError(
                    f"Only support 4D input tensor: node's input shape: {input_shape}"
                )

            in_channels = input_shape[1]
            # groups == in_channels (depthwise case) is left untouched by this pass.
            if groups == in_channels:
                continue
            assert (
                in_channels % groups == 0
            ), f"in_channels should be divisible by groups: in_channels: {in_channels}, groups: {groups}"

            output_shape = extract_shape(node)
            assert len(output_shape) == 4, len(output_shape)

            out_channels = output_shape[1]
            assert (
                out_channels % groups == 0
            ), f"out_channels should be divisible by groups: out_channels: {out_channels}, groups: {groups}"

            weight_shape = extract_shape(weight)
            assert len(weight_shape) == 4, len(weight_shape)
            assert (
                weight_shape[0] == out_channels
            ), f"weight shape[0]: {weight_shape[0]}, out channels: {out_channels}"
            assert (
                weight_shape[1] == in_channels // groups
            ), f"weight shape[1]: {weight_shape[1]}, in channels: {in_channels}"

            if bias is not None:
                bias_shape = extract_shape(bias)
                assert (
                    bias_shape[0] == out_channels
                ), f"bias shape[0]: {bias_shape[0]}, out channels: {out_channels}"
            else:  # Make dummy bias tensor
                # Build the zero bias with the weight's dtype. A default
                # float32 bias would mismatch fp16/bf16 convolutions.
                bias = add_placeholder(
                    exported_program,
                    torch.zeros(out_channels, dtype=weight.meta["val"].dtype),
                    "bias",
                )

            group_size = in_channels // groups
            out_group_size = out_channels // groups

            with gm.graph.inserting_before(node):
                if isinstance(padding, (list, tuple)) and all(
                    isinstance(element, int) for element in padding
                ):
                    conv2d_op = torch.ops.aten.conv2d.default
                elif isinstance(padding, str):
                    conv2d_op = torch.ops.aten.conv2d.padding
                else:
                    raise InvalidArgumentError(
                        f"Unsupported padding type: {padding}"
                    )  # Unreachable to here

                conv2d_tensors = []
                for i in range(groups):
                    # Input channels of group i live on axis 1 ...
                    sliced_input = graph.call_function(
                        torch.ops.aten.slice.Tensor,
                        (input_, 1, group_size * i, group_size * (i + 1), 1),
                    )
                    # ... while its filters and bias live on axis 0.
                    sliced_weight = graph.call_function(
                        torch.ops.aten.slice.Tensor,
                        (weight, 0, out_group_size * i, out_group_size * (i + 1), 1),
                    )
                    sliced_bias = graph.call_function(
                        torch.ops.aten.slice.Tensor,
                        (bias, 0, out_group_size * i, out_group_size * (i + 1), 1),
                    )
                    conv2d_tensor = graph.call_function(
                        conv2d_op,
                        (
                            sliced_input,
                            sliced_weight,
                            sliced_bias,
                            stride,
                            padding,
                            dilation,
                            1,
                        ),
                    )
                    conv2d_tensors.append(conv2d_tensor)

                # Per-group outputs are stacked back along the channel axis.
                concat_output = graph.call_function(
                    torch.ops.aten.cat.default, (conv2d_tensors, 1)
                )

                node.replace_all_uses_with(concat_output, propagate_meta=True)

            modified = True
            logger.debug(
                f"{node.name} is replaced with groups of conv2d: The number of groups: {groups}, groups size: {group_size}"
            )

        graph.eliminate_dead_code()
        gm.recompile()
        return PassResult(modified)
|
@@ -0,0 +1,167 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from dataclasses import dataclass
|
16
|
+
from typing import Optional, TYPE_CHECKING
|
17
|
+
|
18
|
+
if TYPE_CHECKING:
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from torch.export import ExportedProgram
|
22
|
+
|
23
|
+
from tico.serialize.circle_mapping import extract_shape
|
24
|
+
|
25
|
+
from tico.utils import logging
|
26
|
+
from tico.utils.passes import PassBase, PassResult
|
27
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
|
+
from tico.utils.utils import enforce_type
|
29
|
+
|
30
|
+
|
31
|
+
@trace_graph_diff_on_pass
class DecomposeSliceScatter(PassBase):
    """
    Let's decompose slice_scatter.default to cat.

    slice_scatter with step=1 embeds the src tensor into the input tensor.
    We can replace it with (1) slicing the input tensor and (2) concatenating all tensors.

    [1] When step = 1,

        (1) Split input into input_0 and input_1 (either of them can be zero-size)
        (2) Concatenate input_0, src, input_1

        Before)

            input              src
              |                 |
              +--> slice_scatter <--+

        After)

                        input
              +-----------+------------+
              |                        |
          slice_copy               slice_copy
              |                        |
           slice_0*       src       slice_1*
              |            |           |
              +---------> cat <--------+

        *Either of slice_0 or slice_1 could be empty. Then it's omitted.

        `start` is normalized like a Python slice index (negative values count
        from the end and it is clamped to the dimension size). The end of the
        scattered region is taken as `start + src.size(dim)`, so negative or
        oversized `end` values are handled as well.

    [2] When step > 1, not supported yet. (TBD)
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Replace every `aten.slice_scatter.default` node with slice_copy + cat.

        Returns:
            PassResult(modified), True if at least one node was replaced.

        Raises:
            RuntimeError: a slice_scatter node has step > 1.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph: torch.fx.Graph = graph_module.graph
        modified = False

        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if node.target != torch.ops.aten.slice_scatter.default:
                continue

            @enforce_type
            @dataclass
            class Args:
                """
                input (Tensor) the input tensor.
                src (Tensor) The tensor to embed into input
                dim (int) the dimension to insert the slice into
                start (Optional[int]) the start index of where to insert the slice
                end (Optional[int]) the end index of where to insert the slice
                step (int) how many elements to skip in between each element
                """

                input: torch.fx.Node
                src: torch.fx.Node
                dim: int = 0
                start: Optional[int] = None
                end: Optional[int] = None
                step: int = 1

            args = Args(*node.args, **node.kwargs)  # type: ignore[arg-type]

            input_ = args.input
            src = args.src
            dim = args.dim
            step = args.step

            # TODO Support step > 1 cases
            if step > 1:
                raise RuntimeError(
                    f"slice_scatter with step > 1 is not yet supported. Node: {node}"
                )

            dim_size = extract_shape(input_)[dim]
            src_size = extract_shape(src)[dim]

            # Normalize `start` the same way Python/torch slicing does.
            start: int = 0 if args.start is None else args.start
            if start < 0:
                start += dim_size
            start = min(max(start, 0), dim_size)
            # With step == 1, `src` covers exactly input[start : start + src_size]
            # along `dim`. Deriving the end from `src` (rather than `args.end`)
            # makes negative ends and sentinel ends (e.g. sys.maxsize) harmless.
            end: int = start + src_size

            with graph.inserting_before(node):
                slices = []

                if start > 0:
                    slice_0 = graph.call_function(
                        torch.ops.aten.slice_copy.Tensor,
                        args=(input_, dim, 0, start, 1),
                    )
                    slices.append(slice_0)

                slices.append(src)

                if end < dim_size:
                    slice_1 = graph.call_function(
                        torch.ops.aten.slice_copy.Tensor,
                        args=(input_, dim, end, dim_size, 1),
                    )
                    slices.append(slice_1)

                concat = graph.call_function(
                    torch.ops.aten.cat.default, args=(slices, dim)
                )
                # Not set meta for propagating replacing node's meta.
                node.replace_all_uses_with(concat, propagate_meta=True)

            modified = True
            logger.debug(f"{node.name} is replaced with slice_copy + concat")

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(modified)
|
@@ -0,0 +1,121 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
from torch.utils import _pytree as pytree
|
22
|
+
|
23
|
+
from tico.utils import logging
|
24
|
+
from tico.utils.passes import PassBase, PassResult
|
25
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
26
|
+
|
27
|
+
|
28
|
+
def _extract_to_output(node: torch.fx.Node, graph: torch.fx.Graph) -> bool:
|
29
|
+
"""
|
30
|
+
This extracts dtype kwargs to node's output direction
|
31
|
+
|
32
|
+
So, op(..., dtype = X) is converted to op(...).to(X)
|
33
|
+
|
34
|
+
Return true if modified
|
35
|
+
|
36
|
+
NOTE
|
37
|
+
|
38
|
+
[1] This function always returns true. Return value is introduced for extension
|
39
|
+
[2] This conversion is not safe for some Ops whose inputs should also be casted to X (ex: Mean).
|
40
|
+
|
41
|
+
"""
|
42
|
+
logger = logging.getLogger(__name__)
|
43
|
+
|
44
|
+
node_kwargs = node.kwargs
|
45
|
+
# Remove "dtype" from node's kwargs
|
46
|
+
new_kwargs = {}
|
47
|
+
for k, v in node_kwargs.items():
|
48
|
+
if k == "dtype":
|
49
|
+
continue
|
50
|
+
new_kwargs[k] = v
|
51
|
+
node.kwargs = new_kwargs
|
52
|
+
# Create new val for node
|
53
|
+
# `node.target()` needs only `Tensor` for its arguments. Therefore, let's retrieve `FakeTensor` if it is `torch.fx.Node`.
|
54
|
+
args, kwargs = pytree.tree_map_only(
|
55
|
+
torch.fx.Node, lambda x: x.meta["val"], (node.args, node.kwargs)
|
56
|
+
)
|
57
|
+
new_val = node.target(*args, **kwargs) # type: ignore[operator]
|
58
|
+
# Set args, kwargs of `to_copy`
|
59
|
+
to_args = (node,)
|
60
|
+
to_kwargs = {"dtype": node_kwargs["dtype"]}
|
61
|
+
with graph.inserting_after(node):
|
62
|
+
to_copy = graph.call_function(torch.ops.aten._to_copy.default, (), {})
|
63
|
+
node.replace_all_uses_with(to_copy, propagate_meta=True)
|
64
|
+
# Q) Why lazy-update args, kwargs of the `to_copy`?
|
65
|
+
# A) `replace_all_uses_with` replace all the uses of `node`. If `to_copy` args is set to
|
66
|
+
# (node, ) before `replace_all_uses_with`, the function would even replace the args of
|
67
|
+
# `to_copy` with `to_copy`.
|
68
|
+
to_copy.args = to_args
|
69
|
+
to_copy.kwargs = to_kwargs
|
70
|
+
# Update meta["val"] to change dtype
|
71
|
+
node.meta["val"] = new_val
|
72
|
+
|
73
|
+
logger.debug(f"{node.name}'s dtype kwargs is extracted into {to_copy.name}")
|
74
|
+
|
75
|
+
return True
|
76
|
+
|
77
|
+
|
78
|
+
@trace_graph_diff_on_pass
class ExtractDtypeKwargsPass(PassBase):
    """
    This pass extracts the "dtype" keyword argument from nodes.

    Sometimes, a torch api receives a "dtype" keyword argument.

        E.g. x_bool = torch.full_like(x, 0, dtype=torch.bool)

    But this argument makes the circle build logic complicated, because many
    operators are expected to have the same dtype as their inputs.

    So, this pass changes `op(dtype)` to `op + to(dtype)`.

    NOTE

    [1] Only ops registered in `self.target_ops` are rewritten (currently
        `aten.full_like.default`). Some ops naturally take a "dtype" kwarg;
        they are intentionally not registered.
    [2] The pass does NOT check whether node.kwargs["dtype"] is redundant
        (`op(dtype).dtype == op().dtype`). Every registered node carrying a
        "dtype" kwarg is rewritten, which may produce a same-dtype `_to_copy`.
    """

    def __init__(self):
        super().__init__()
        # Maps an op overload to the function that extracts its "dtype" kwarg.
        # Each function takes (node, graph) and returns True if it modified the graph.
        self.target_ops = {
            torch.ops.aten.full_like.default: _extract_to_output,
        }

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Apply the registered extraction to every matching node carrying "dtype".

        Returns:
            PassResult(modified), True if any extraction function reported a change.
        """
        graph_module = exported_program.graph_module
        graph: torch.fx.Graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if node.op != "call_function" or node.target not in self.target_ops:
                continue
            if "dtype" not in node.kwargs:
                continue

            # The extractor inserts its `_to_copy` right after `node`; that new
            # node is not a registered target, so visiting it next is harmless.
            modified |= self.target_ops[node.target](node, graph)

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
@@ -0,0 +1,57 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from torch.export import ExportedProgram
|
16
|
+
|
17
|
+
from tico.utils import logging
|
18
|
+
from tico.utils.passes import PassBase, PassResult
|
19
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
20
|
+
from tico.utils.utils import set_new_meta_val
|
21
|
+
|
22
|
+
|
23
|
+
@trace_graph_diff_on_pass
class FillMetaVal(PassBase):
    """
    Compute meta['val'] for every call_function node that does not have one yet.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Fill in missing meta['val'] entries via `set_new_meta_val`.

        Returns:
            PassResult(changed), True if at least one node received a new value.
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        g = gm.graph
        changed = False

        # lint() validates the graph, including its topological order, before
        # the nodes are walked in sequence.
        g.lint()

        for n in g.nodes:
            already_has_val = hasattr(n, "meta") and "val" in n.meta
            if n.op != "call_function" or already_has_val:
                continue

            set_new_meta_val(n)
            changed = True
            logger.debug(f"{n.name} has new meta values.")

        g.eliminate_dead_code()
        g.lint()
        gm.recompile()

        return PassResult(changed)
|
@@ -0,0 +1,102 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
from torch.utils import _pytree as pytree
|
22
|
+
|
23
|
+
from tico.passes import ops
|
24
|
+
from tico.serialize.circle_mapping import extract_shape
|
25
|
+
from tico.utils import logging
|
26
|
+
from tico.utils.passes import PassBase, PassResult
|
27
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
|
+
|
29
|
+
|
30
|
+
@trace_graph_diff_on_pass
class FuseRedundantReshapeToMean(PassBase):
    """
    This pass removes redundant `aten.reshape` operators that can be fused into `aten.mean` with `keep_dims`.

        Shape(aten.reshape(aten.mean(input))) == Shape(aten.mean(input, keep_dims=True))

    The fusion is attempted whatever the mean's current `keepdim` argument is
    (omitted, True, or an explicit False). A reshape that does not change the
    mean's output shape at all is also removed, leaving the mean unchanged.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Fold `reshape(mean.dim(...))` into `mean.dim(..., keepdim=True)` where shapes allow.

        Returns:
            PassResult(modified), True if at least one reshape was removed.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if node.op != "call_function":
                continue

            if node.target != torch.ops.aten.mean.dim:
                continue

            # If mean is being used in other nodes, do not fuse it.
            if len(node.users) != 1:
                continue

            user_node = next(iter(node.users))
            if user_node.target not in ops.aten.reshape:
                continue

            mean_args, mean_kwargs = pytree.tree_map_only(
                torch.fx.Node,
                lambda n: n.meta["val"],
                (node.args, node.kwargs),
            )
            # Signature of aten.mean.dim is as follows.
            # mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
            # `keepdim` in `node.kwargs` is moved to `node.args` in `run_decompositions`.
            # `dtype` in `node.kwargs` is not moved
            assert len(mean_args) in (2, 3)  # keepdim exists or not
            assert len(mean_kwargs) <= 1  # dtype exists or not

            reshape_shape = extract_shape(user_node)

            # Always evaluate with keepdim=True. An explicit keepdim=False (as
            # left by run_decompositions) must not prevent the fusion.
            keep_dims = True
            fused_val = node.target(*mean_args[:2], keep_dims, **mean_kwargs)

            # Check if both shapes are same
            # 1. Shape(aten.reshape(aten.mean))
            # 2. Shape(aten.mean(keep_dims=True))
            if fused_val.size() == reshape_shape:
                node.args = node.args[:2] + (keep_dims,)
                node.meta["val"] = fused_val
            elif node.target(*mean_args, **mean_kwargs).size() != reshape_shape:
                continue
            # Otherwise the reshape is an identity on mean's output: drop it and
            # keep mean's arguments as they are.

            user_node.replace_all_uses_with(node, propagate_meta=False)

            modified = True
            logger.debug(f"{user_node.name} is replaced with {node.name}")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|